From 0dea35c4e247265f9efd1472fde60867c81c028f Mon Sep 17 00:00:00 2001
From: AlexDuvalinho
Date: Thu, 18 Apr 2024 04:32:38 -0400
Subject: [PATCH 01/45] fix cluster issue gpus-per-task

---
 mila/sbatch.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mila/sbatch.py b/mila/sbatch.py
index f2cc71538a..ed8fa878d6 100644
--- a/mila/sbatch.py
+++ b/mila/sbatch.py
@@ -41,7 +41,7 @@ conda activate {env}
 fi
 {wandb_offline}
-srun --gpus-per-task=1 --output={output} {python_command}
+srun --output={output} {python_command}
 """

From 904fea19b457440d405e13321f3af68497dd9fe9 Mon Sep 17 00:00:00 2001
From: AlexDuvalinho
Date: Thu, 18 Apr 2024 08:35:34 -0400
Subject: [PATCH 02/45] new yaml configs

---
 configs/exps/deup/datasets/new-mc-faenet.yaml | 28 ++++++++++++
 configs/exps/deup/gnn/depfaenet.yaml          |  0
 configs/exps/deup/gnn/faenet-training.yaml    | 43 +++++++++++++++++++
 configs/exps/deup/uncertainty/v0.yaml         |  1 -
 configs/exps/deup/uncertainty/v1.yaml         | 33 ++++++++++++++
 5 files changed, 104 insertions(+), 1 deletion(-)
 create mode 100644 configs/exps/deup/datasets/new-mc-faenet.yaml
 create mode 100644 configs/exps/deup/gnn/depfaenet.yaml
 create mode 100644 configs/exps/deup/gnn/faenet-training.yaml
 create mode 100644 configs/exps/deup/uncertainty/v1.yaml

diff --git a/configs/exps/deup/datasets/new-mc-faenet.yaml b/configs/exps/deup/datasets/new-mc-faenet.yaml
new file mode 100644
index 0000000000..95aab7a29a
--- /dev/null
+++ b/configs/exps/deup/datasets/new-mc-faenet.yaml
@@ -0,0 +1,28 @@
+job:
+  mem: 32GB
+  cpus: 4
+  gres: gpu:1
+  partition: long
+
+default:
+  config: faenet-is2re-all
+  wandb_project: ocp-deup
+  wandb_tags: base-model, MC-D, 2935198
+  test_ri: True
+  mode: train
+  checkpoint: /network/scratch/a/alexandre.duval/ocp/runs/2935198/checkpoints/best_checkpoint.pt
+  restart_from_dir: /network/scratch/a/alexandre.duval/ocp/runs/2935198
+  model:
+    dropout_lowest_layer: output
+    first_trainable_layer: dropout
+    dropout_lin: 0.7
+  cp_data_to_tmpdir: true
+  inference_time_loops: 1
+  deup_dataset:
+    create: after # "before" -> created before training (for deup) "after" -> created after training (for is2re) "" - not created
+    dataset_strs: ["train", "val_id", "val_ood_cat", "val_ood_ads"]
+    n_samples: 7
+
+runs:
+  - optim:
+      max_epochs: 12
diff --git a/configs/exps/deup/gnn/depfaenet.yaml b/configs/exps/deup/gnn/depfaenet.yaml
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/configs/exps/deup/gnn/faenet-training.yaml b/configs/exps/deup/gnn/faenet-training.yaml
new file mode 100644
index 0000000000..5e55752639
--- /dev/null
+++ b/configs/exps/deup/gnn/faenet-training.yaml
@@ -0,0 +1,43 @@
+job:
+  mem: 32GB
+  cpus: 4
+  gres: gpu:1
+  partition: long
+  time: 18:00:00
+
+default:
+  test_ri: True
+  mode: train
+  graph_rewiring: remove-tag-0
+  wandb_tags: "top-model"
+  wandb_project: ocp-deup
+  optim:
+    batch_size: 256
+    eval_batch_size: 256
+  cp_data_to_tmpdir: True
+
+runs:
+  - config: faenet-is2re-all
+    note: "top-runs"
+    frame_averaging: 2D
+    fa_method: se3-random
+    model:
+      mp_type: updownscale_base
+      phys_embeds: True
+      tag_hidden_channels: 32
+      pg_hidden_channels: 96
+      energy_head: weighted-av-final-embeds
+      complex_mp: True
+      graph_norm: True
+      hidden_channels: 384
+      num_filters: 480
+      num_gaussians: 104
+      num_interactions: 5
+      second_layer_MLP: False
+      skip_co: concat
+      cutoff: 6.0
+    optim:
+      lr_initial: 0.002
+      scheduler: LinearWarmupCosineAnnealingLR
+      max_epochs: 12
+      eval_every: 0.25
\ No newline at end of file
diff --git a/configs/exps/deup/uncertainty/v0.yaml b/configs/exps/deup/uncertainty/v0.yaml
index 94597ddaf3..4cdd6d8027 100644
--- a/configs/exps/deup/uncertainty/v0.yaml
+++ b/configs/exps/deup/uncertainty/v0.yaml
@@ -3,7 +3,6 @@ job:
   cpus: 4
   gres: gpu:rtx8000:1
   partition: long
-  code_loc: /home/mila/s/schmidtv/ocp-project/run-repos/ocp-3
 
 default:
   config: deup_faenet-deup_is2re-all
diff --git a/configs/exps/deup/uncertainty/v1.yaml b/configs/exps/deup/uncertainty/v1.yaml
new file mode 100644
index 0000000000..4f69d7828f
--- /dev/null
+++ b/configs/exps/deup/uncertainty/v1.yaml
@@ -0,0 +1,33 @@
+job:
+  mem: 32GB
+  cpus: 4
+  gres: gpu:1
+  partition: long
+
+default:
+  config: deup_faenet-deup_is2re-all
+
+  wandb_project: ocp-deup
+  wandb_tags: base-model, MC-D, 3264530
+  test_ri: True
+  mode: train
+  model:
+    dropout_lowest_layer: null
+    first_trainable_layer: output
+    dropout_lin: 0.7
+  cp_data_to_tmpdir: false
+  inference_time_loops: 1
+  restart_from_dir: /network/scratch/s/schmidtv/ocp/runs/3264530
+  checkpoint: /network/scratch/s/schmidtv/ocp/runs/3264530
+  dataset: # mandatory if restart_from_dir is set
+    default_val: deup-val_ood_cat-val_ood_ads
+    deup-train-val_id:
+      src: /network/scratch/s/schmidtv/ocp/runs/3264530/deup_dataset
+    deup-val_ood_cat-val_ood_ads:
+      src: /network/scratch/s/schmidtv/ocp/runs/3264530/deup_dataset
+  deup_dataset:
+    create: False
+
+runs:
+  - optim:
+      max_epochs: 12

From 7c481391ac6cbe5a77ee068518296e053dcf8930 Mon Sep 17 00:00:00 2001
From: AlexDuvalinho
Date: Thu, 18 Apr 2024 09:30:24 -0400
Subject: [PATCH 03/45] update path trained gnn model

---
 configs/exps/deup/datasets/new-mc-faenet.yaml | 6 +++---
 ocpmodels/datasets/deup_dataset_creator.py    | 2 +-
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/configs/exps/deup/datasets/new-mc-faenet.yaml b/configs/exps/deup/datasets/new-mc-faenet.yaml
index 95aab7a29a..56ea298687 100644
--- a/configs/exps/deup/datasets/new-mc-faenet.yaml
+++ b/configs/exps/deup/datasets/new-mc-faenet.yaml
@@ -7,11 +7,11 @@ job:
 default:
   config: faenet-is2re-all
   wandb_project: ocp-deup
-  wandb_tags: base-model, MC-D, 2935198
+  wandb_tags: base-model, MC-D, 4615191
   test_ri: True
   mode: train
-  checkpoint: /network/scratch/a/alexandre.duval/ocp/runs/2935198/checkpoints/best_checkpoint.pt
-  restart_from_dir: /network/scratch/a/alexandre.duval/ocp/runs/2935198
+  checkpoint: /network/scratch/a/alexandre.duval/scratch/ocp/runs/4615191/checkpoints/best_checkpoint.pt
+  restart_from_dir: /network/scratch/a/alexandre.duval/ocp/runs/4615191/
   model:
     dropout_lowest_layer: output
     first_trainable_layer: dropout
diff --git a/ocpmodels/datasets/deup_dataset_creator.py b/ocpmodels/datasets/deup_dataset_creator.py
index 64d67fd164..4bc6a8bc0e 100644
--- a/ocpmodels/datasets/deup_dataset_creator.py
+++ b/ocpmodels/datasets/deup_dataset_creator.py
@@ -431,7 +431,7 @@ def write_lmdb(self, samples, path, total_size=-1, max_samples=-1):
     from ocpmodels.datasets.lmdb_dataset import DeupDataset
     from ocpmodels.common.utils import JOB_ID, RUNS_DIR, make_config_from_conf_str
 
-    base_trainer_path = "/network/scratch/s/schmidtv/ocp/runs/3298908"
+    base_trainer_path = "/network/scratch/a/alexandre.duval/ocp/runs/4615191"
 
     # what models to load for inference
     trainers_conf = {

From 72ae772108c95dea3dc5b0501bb46bc813261fef Mon Sep 17 00:00:00 2001
From: AlexDuvalinho
Date: Thu, 18 Apr 2024 09:30:41 -0400
Subject: [PATCH 04/45] fa_frames => fa_method

---
 configs/models/deup_faenet.yaml            | 2 +-
 ocpmodels/datasets/deup_dataset_creator.py | 2 +-
 ocpmodels/datasets/lmdb_dataset.py         | 6 +++---
 scripts/train_density_estimator.py         | 2 +-
 4 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/configs/models/deup_faenet.yaml b/configs/models/deup_faenet.yaml
index 2284687e2a..efa779c801 100644
--- a/configs/models/deup_faenet.yaml
+++ b/configs/models/deup_faenet.yaml
@@ -57,7 +57,7 @@ default:
     energy_coefficient: 1
 
   frame_averaging: False # 2D, 3D, da, False
-  fa_frames: False # can be {None, full, random, det, e3, e3-random, e3-det}
+  fa_method: False # can be {None, full, random, det, e3, e3-random, e3-det}
 
 # -------------------
 # ----- IS2RE -----
diff --git a/ocpmodels/datasets/deup_dataset_creator.py b/ocpmodels/datasets/deup_dataset_creator.py
index 4bc6a8bc0e..1af5cdf8fd 100644
--- a/ocpmodels/datasets/deup_dataset_creator.py
+++ b/ocpmodels/datasets/deup_dataset_creator.py
@@ -167,7 +167,7 @@ def load_trainers(self, overrides={}):
 
         shared_config = {}
         shared_config["graph_rewiring"] = self.trainers[0].config["graph_rewiring"]
-        shared_config["fa_frames"] = self.trainers[0].config["fa_frames"]
+        shared_config["fa_method"] = self.trainers[0].config["fa_method"]
         shared_config["frame_averaging"] = self.trainers[0].config["frame_averaging"]
 
         # Done!
diff --git a/ocpmodels/datasets/lmdb_dataset.py b/ocpmodels/datasets/lmdb_dataset.py
index e4ea6bd7bc..8f8fb24442 100644
--- a/ocpmodels/datasets/lmdb_dataset.py
+++ b/ocpmodels/datasets/lmdb_dataset.py
@@ -37,7 +37,7 @@ class LmdbDataset(Dataset):
         config (dict): Dataset configuration
         transform (callable, optional): Data transform function.
             (default: :obj:`None`)
-        fa_frames (str, optional): type of frame averaging method applied, if any.
+        fa_method (str, optional): type of frame averaging method applied, if any.
         adsorbates (str, optional): comma-separated list of adsorbates to filter.
             If None or "all", no filtering is applied.
             (default: None)
@@ -49,7 +49,7 @@ def __init__(
         self,
         config,
         transform=None,
-        fa_frames=None,
+        fa_method=None,
         lmdb_glob=None,
         adsorbates=None,
         adsorbates_ref_dir=None,
@@ -96,7 +96,7 @@ def __init__(
             self.filter_per_adsorbates()
 
         self.transform = transform
-        self.fa_method = fa_frames
+        self.fa_method = fa_method
 
     def filter_per_adsorbates(self):
         """Filter the dataset to only include structures with a specific
diff --git a/scripts/train_density_estimator.py b/scripts/train_density_estimator.py
index a7a45b3273..b5f5bc4913 100644
--- a/scripts/train_density_estimator.py
+++ b/scripts/train_density_estimator.py
@@ -303,7 +303,7 @@ def validate(epoch, model, loader):
             "num_workers": 0,
         },
         "frame_averaging": None,
-        "fa_frames": None,
+        "fa_method": None,
         "silent": False,
         "graph_rewiring": "remove-tag-0",
         "de": {

From 040b475d479cd3a3ed7f48f1b357456fa4e8cf94 Mon Sep 17 00:00:00 2001
From: AlexDuvalinho
Date: Thu, 18 Apr 2024 09:54:07 -0400
Subject: [PATCH 05/45] skip_co = concat is not possible

---
 .../exps/deup/gnn/{depfaenet.yaml => depfaenet-training.yaml} | 0
 configs/exps/deup/gnn/faenet-training.yaml                    | 4 ++--
 ocpmodels/datasets/deup_dataset_creator.py                    | 1 +
 3 files changed, 3 insertions(+), 2 deletions(-)
 rename configs/exps/deup/gnn/{depfaenet.yaml => depfaenet-training.yaml} (100%)

diff --git a/configs/exps/deup/gnn/depfaenet.yaml b/configs/exps/deup/gnn/depfaenet-training.yaml
similarity index 100%
rename from configs/exps/deup/gnn/depfaenet.yaml
rename to configs/exps/deup/gnn/depfaenet-training.yaml
diff --git a/configs/exps/deup/gnn/faenet-training.yaml b/configs/exps/deup/gnn/faenet-training.yaml
index 5e55752639..8bf38ec5f2 100644
--- a/configs/exps/deup/gnn/faenet-training.yaml
+++ b/configs/exps/deup/gnn/faenet-training.yaml
@@ -18,7 +18,7 @@ default:
 
 runs:
   - config: faenet-is2re-all
-    note: "top-runs"
+    note: "top run no concat"
    frame_averaging: 2D
     fa_method: se3-random
     model:
@@ -34,7 +34,7 @@ runs:
       num_gaussians: 104
       num_interactions: 5
       second_layer_MLP: False
-      skip_co: concat
+      skip_co: False
       cutoff: 6.0
     optim:
       lr_initial: 0.002
diff --git a/ocpmodels/datasets/deup_dataset_creator.py b/ocpmodels/datasets/deup_dataset_creator.py
index 1af5cdf8fd..b57522422e 100644
--- a/ocpmodels/datasets/deup_dataset_creator.py
+++ b/ocpmodels/datasets/deup_dataset_creator.py
@@ -306,6 +306,7 @@ def create_deup_dataset(
 
         stats = {d: {} for d in dataset_strs}
 
+        # Loop on train, val_id, val_ood_cat, val_ood_ads
         for dataset_name in dataset_strs:
             deup_samples = []
             deup_ds_size = 0

From 83659c643d744a1f156b6c6c931930a58615a967 Mon Sep 17 00:00:00 2001
From: AlexDuvalinho
Date: Fri, 19 Apr 2024 05:55:16 -0400
Subject: [PATCH 06/45] Merge only relevant changes from disconnected_gnn
 branch, to run depfaenet

---
 configs/exps/catalyst/gflownet.yaml          | 143 ++++++++++
 configs/exps/catalyst/reproduce-configs.yaml |  75 +++++
 configs/models/depfaenet.yaml                | 271 +++++++++++++++++++
 configs/models/painn.yaml                    |   3 +
 mila/sbatch.py                               |  11 +-
 ocpmodels/common/flags.py                    |  18 +-
 ocpmodels/datasets/data_transforms.py        |  30 ++
 ocpmodels/models/__init__.py                 |   1 +
 ocpmodels/models/base_model.py               |  26 +-
 ocpmodels/models/depfaenet.py                |  97 +++++++
 ocpmodels/preprocessing/graph_rewiring.py    |   5 +
 ocpmodels/trainers/base_trainer.py           |  32 ++-
 ocpmodels/trainers/single_trainer.py         |   8 +-
 scripts/debug_faenet.py                      | 222 +++++++++++++++
 14 files changed, 929 insertions(+), 13 deletions(-)
 create mode 100644 configs/exps/catalyst/gflownet.yaml
 create mode 100644 configs/exps/catalyst/reproduce-configs.yaml
 create mode 100644 configs/models/depfaenet.yaml
 create mode 100644 ocpmodels/models/depfaenet.py
 create mode 100644 scripts/debug_faenet.py

diff --git a/configs/exps/catalyst/gflownet.yaml b/configs/exps/catalyst/gflownet.yaml
new file mode 100644
index 0000000000..2432f47339
--- /dev/null
+++ b/configs/exps/catalyst/gflownet.yaml
@@ -0,0 +1,143 @@
+job:
+  mem: 32GB
+  cpus: 4
+  gres: gpu:rtx8000:1
+  partition: long
+  time: 15:00:00
+
+default:
+  # wandb_name: alvaro-carbonero-math
+  wandb_project: ocp-alvaro
+  wandb_tags: "gflownet-model"
+  test_ri: True
+  mode: train
+  # graph_rewiring: remove-tag-0
+  graph_rewiring: ""
+  frame_averaging: 2D
+  fa_method: se3-random
+  cp_data_to_tmpdir: True
+  is_disconnected: true
+  model:
+    edge_embed_type: all_rij
+    mp_type: updownscale_base
+    phys_embeds: True
+    tag_hidden_channels: 0
+    pg_hidden_channels: 96
+    energy_head: weighted-av-final-embeds
+    complex_mp: True
+    graph_norm: True
+    hidden_channels: 352
+    num_filters: 288
+    num_gaussians: 68
+    num_interactions: 5
+    second_layer_MLP: False
+    skip_co: concat
+    cutoff: 4.0
+  optim:
+    batch_size: 256
+    eval_batch_size: 256
+    lr_initial: 0.002
+    scheduler: LinearWarmupCosineAnnealingLR
+    max_epochs: 9
+    eval_every: 0.4
+
+runs:
+
+  # - config: faenet-is2re-all
+  #   note: baseline faenet
+
+  # - config: depfaenet-is2re-all
+  #   note: depfaenet baseline
+
+  # - config: depfaenet-is2re-all
+  #   note: depfaenet per-adsorbate
+  #   adsorbates: {'*O', '*OH', '*OH2', '*H'}
+
+  # - config: depfaenet-is2re-all
+  #   note: depfaenet per-adsorbate long string
+  #   adsorbates: '*O, *OH, *OH2, *H'
+
+  # - config: depfaenet-is2re-all
+  #   note: depfaenet per-adsorbate string of a list
+  #   adsorbates: "*O, *OH, *OH2, *H"
+
+  # - config: depfaenet-is2re-all
+  #   note: Trained on selected adsorbate more epochs
+  #   adsorbates: "*O, *OH, *OH2, *H"
+  #   optim:
+  #     max_epochs: 10
+
+  # - config: depfaenet-is2re-all
+  #   note: depfaenet full data
+
+  # - config: depfaenet-is2re-all
+  #   note: To be used for continue from dir
+
+  # - config: depfaenet-is2re-all
+  #   note: Fine-tune on per-ads-dataset 4 epoch
+  #   continue_from_dir: /network/scratch/a/alexandre.duval/ocp/runs/4023244
+  #   adsorbates: "*O, *OH, *OH2, *H"
+  #   optim:
+  #     max_epochs: 4
+  #     lr_initial: 0.00015
+
+  # - config: depfaenet-is2re-all
+  #   note: Fine-tune on per-ads-dataset 10 epoch
+  #   continue_from_dir: /network/scratch/a/alexandre.duval/ocp/runs/4023244
+  #   adsorbates: "*O, *OH, *OH2, *H"
+  #   optim:
+  #     max_epochs: 10
+  #     lr_initial: 0.00015
+
+  - config: depfaenet-is2re-all
+    note: Fine-tune on per-ads-dataset 10 epoch
+    continue_from_dir: /network/scratch/a/alexandre.duval/ocp/runs/4023244
+    adsorbates: "*O, *OH, *OH2, *H"
+    optim:
+      max_epochs: 20
+      lr_initial: 0.0001
+
+  - config: depfaenet-is2re-all
+    note: Fine-tune on per-ads-dataset 20 epoch
+    continue_from_dir: /network/scratch/a/alexandre.duval/ocp/runs/4023244
+    adsorbates: "*O, *OH, *OH2, *H"
+    optim:
+      max_epochs: 20
+      lr_initial: 0.00015
+
+  - config: depfaenet-is2re-all
+    note: Fine-tune on per-ads-dataset 15 epoch
+    continue_from_dir: /network/scratch/a/alexandre.duval/ocp/runs/4023244
+    adsorbates: "*O, *OH, *OH2, *H"
+    optim:
+      max_epochs: 15
+      lr_initial: 0.0002
+
+  - config: depfaenet-is2re-all
+    note: Fine-tune on per-ads-dataset 10 epoch
+    continue_from_dir: /network/scratch/a/alexandre.duval/ocp/runs/4023244
+    adsorbates: "*O, *OH, *OH2, *H"
+    optim:
+      max_epochs: 10
+      lr_initial: 0.0001
+
+  - config: depfaenet-is2re-all
+    note: Fine-tune on per-ads-dataset starting from fine-tuned model
+    continue_from_dir: /network/scratch/a/alexandre.duval/ocp/runs/4071859
+    adsorbates: "*O, *OH, *OH2, *H"
+    optim:
+      max_epochs: 10
+      lr_initial: 0.0001
+
+  - config: depfaenet-is2re-all
+    note: Trained on selected adsorbate
+    adsorbates: "*O, *OH, *OH2, *H"
+    optim:
+      max_epochs: 25
+      lr_initial: 0.0001
+
+  - config: depfaenet-is2re-all
+    note: Trained on selected adsorbate
+    adsorbates: "*O, *OH, *OH2, *H"
+    optim:
+      max_epochs: 25
diff --git a/configs/exps/catalyst/reproduce-configs.yaml b/configs/exps/catalyst/reproduce-configs.yaml
new file mode 100644
index 0000000000..c4c834585c
--- /dev/null
+++ b/configs/exps/catalyst/reproduce-configs.yaml
@@ -0,0 +1,75 @@
+job:
+  mem: 32GB
+  cpus: 4
+  gres: gpu:rtx8000:1
+  partition: long
+  time: 15:00:00
+
+default:
+  # wandb_name: alvaro-carbonero-math
+  wandb_project: ocp-alvaro
+  wandb_tags: "reproduce-best-config"
+  test_ri: True
+  mode: train
+  graph_rewiring: remove-tag-0
+  note: "reproduce-top-run"
+  frame_averaging: 2D
+  fa_method: se3-random
+  cp_data_to_tmpdir: True
+  is_disconnected: true
+  model:
+    edge_embed_type: all_rij
+    mp_type: updownscale_base
+    phys_embeds: True
+    tag_hidden_channels: 32
+    pg_hidden_channels: 96
+    energy_head: weighted-av-final-embeds
+    complex_mp: True
+    graph_norm: True
+    hidden_channels: 352
+    num_filters: 288
+    num_gaussians: 68
+    num_interactions: 5
+    second_layer_MLP: False
+    skip_co: concat
+    cutoff: 4.0
+  optim:
+    batch_size: 256
+    eval_batch_size: 256
+    lr_initial: 0.002
+    scheduler: LinearWarmupCosineAnnealingLR
+    max_epochs: 9
+    eval_every: 0.4
+
+runs:
+
+  - config: faenet-is2re-all
+    note: baseline faenet
+
+  - config: indfaenet-is2re-all
+    note: baseline with top configs
+
+  - config: indfaenet-is2re-all
+    note: baseline with runs' configs
+    model:
+      tag_hidden_channels: 32
+      pg_hidden_channels: 96
+      energy_head: weighted-av-final-embeds
+      complex_mp: True
+      graph_norm: True
+      hidden_channels: 528
+      num_filters: 672
+      num_gaussians: 148
+      num_interactions: 5
+      second_layer_MLP: False
+      skip_co: concat
+
+  - config: depfaenet-is2re-all
+    note: baseline with top configs
+
+  - config: indfaenet-is2re-all
+    note: so that ads get old dimensions
+    model:
+      hidden_channels: 704
+      num_gaussians: 200
+      num_filters: 896
\ No newline at end of file
diff --git a/configs/models/depfaenet.yaml b/configs/models/depfaenet.yaml
new file mode 100644
index 0000000000..852ebc3bfd
--- /dev/null
+++ b/configs/models/depfaenet.yaml
@@ -0,0 +1,271 @@
+default:
+  model:
+    name: depfaenet
+    act: swish
+    hidden_channels: 128
+    num_filters: 100
+    num_interactions: 3
+    num_gaussians: 100
+    cutoff: 6.0
+    use_pbc: True
+    regress_forces: False
+    # drlab attributes:
+    tag_hidden_channels: 0 # 32
+    pg_hidden_channels: 0 # 32 -> period & group embedding hidden channels
+    phys_embeds: False # True
+    phys_hidden_channels: 0
+    energy_head: False # can be {False, weighted-av-initial-embeds, weighted-av-final-embeds, pooling, graclus, random}
+    # faenet new features
+    skip_co: False # output skip connections {False, "add", "concat"}
+    second_layer_MLP: False # in EmbeddingBlock
+    complex_mp: False
+    edge_embed_type: rij # {'rij','all_rij','sh', 'all'})
+    mp_type: base # {'base', 'simple', 'updownscale', 'att', 'base_with_att', 'local_env'}
+    graph_norm: False # bool
+    att_heads: 1 # int
+    force_decoder_type: "mlp" # can be {"" or "simple"} | only used if regress_forces is True
+    force_decoder_model_config:
+      simple:
+        hidden_channels: 128
+        norm: batch1d # batch1d, layer or null
+      mlp:
+        hidden_channels: 256
+        norm: batch1d # batch1d, layer or null
+      res:
+        hidden_channels: 128
+        norm: batch1d # batch1d, layer or null
+      res_updown:
+        hidden_channels: 128
+        norm: batch1d # batch1d, layer or null
+  optim:
+    batch_size: 64
+    eval_batch_size: 64
+    num_workers: 4
+    lr_gamma: 0.1
+    lr_initial: 0.001
+    warmup_factor: 0.2
+    max_epochs: 20
+    energy_grad_coefficient: 10
+    force_coefficient: 30
+    energy_coefficient: 1
+
+  frame_averaging: False # 2D, 3D, da, False
+  fa_frames: False # can be {None, full, random, det, e3, e3-random, e3-det}
+
+# -------------------
+# ----- IS2RE -----
+# -------------------
+
+is2re:
+  # *** Important note ***
+  #   The total number of gpus used for this run was 1.
+  #   If the global batch size (num_gpus * batch_size) is modified
+  #   the lr_milestones and warmup_steps need to be adjusted accordingly.
+  10k:
+    optim:
+      lr_initial: 0.005
+      lr_milestones: # epochs at which lr_initial <- lr_initial * lr_gamma
+        - 1562
+        - 2343
+        - 3125
+      warmup_steps: 468
+      max_epochs: 20
+
+  100k:
+    model:
+      hidden_channels: 256
+    optim:
+      lr_initial: 0.005
+      lr_milestones: # epochs at which lr_initial <- lr_initial * lr_gamma
+        - 1562
+        - 2343
+        - 3125
+      warmup_steps: 468
+      max_epochs: 20
+
+  all:
+    model:
+      hidden_channels: 384
+      num_interactions: 4
+    optim:
+      batch_size: 256
+      eval_batch_size: 256
+      lr_initial: 0.001
+      lr_gamma: 0.1
+      lr_milestones: # steps at which lr_initial <- lr_initial * lr_gamma
+        - 18000
+        - 27000
+        - 37000
+      warmup_steps: 6000
+      max_epochs: 20
+
+# ------------------
+# ----- S2EF -----
+# ------------------
+
+# For 2 GPUs
+
+s2ef:
+  default:
+    model:
+      num_interactions: 4
+      hidden_channels: 750
+      num_gaussians: 200
+      num_filters: 256
+      regress_forces: "direct"
+      force_coefficient: 30
+      energy_grad_coefficient: 10
+    optim:
+      batch_size: 96
+      eval_batch_size: 96
+      warmup_factor: 0.2
+      lr_gamma: 0.1
+      lr_initial: 0.0001
+      max_epochs: 15
+      warmup_steps: 30000
+      lr_milestones:
+        - 55000
+        - 75000
+        - 10000
+
+  200k: {}
+
+  # 1 gpus
+  2M:
+    model:
+      num_interactions: 5
+      hidden_channels: 1024
+      num_gaussians: 200
+      num_filters: 256
+    optim:
+      batch_size: 192
+      eval_batch_size: 192
+
+  20M: {}
+
+  all: {}
+
+qm9:
+  default:
+    model:
+      act: swish
+      att_heads: 1
+      complex_mp: true
+      cutoff: 6.0
+      edge_embed_type: all_rij
+      energy_head: ''
+      graph_norm: true
+      graph_rewiring: null
+      hidden_channels: 400
+      max_num_neighbors: 30
+      mp_type: updownscale_base
+      num_filters: 480
+      num_gaussians: 100
+      num_interactions: 5
+      otf_graph: false
+      pg_hidden_channels: 32
+      phys_embeds: false
+      phys_hidden_channels: 0
+      regress_forces: ''
+      second_layer_MLP: true
+      skip_co: true
+      tag_hidden_channels: 0
+      use_pbc: false
+
+    optim:
+      batch_size: 64
+      es_min_abs_change: 1.0e-06
+      es_patience: 20
+      es_warmup_epochs: 600
+      eval_batch_size: 64
+      factor: 0.9
+      lr_initial: 0.0003
+      loss_energy: mse
+      lr_gamma: 0.1
+      lr_initial: 0.001
+      max_epochs: 1500
+      min_lr: 1.0e-06
+      mode: min
+      optimizer: AdamW
+      patience: 15
+      scheduler: ReduceLROnPlateau
+      threshold: 0.0001
+      threshold_mode: abs
+      verbose: true
+      warmup_factor: 0.2
+      warmup_steps: 3000
+
+  10k: {}
+  all: {}
+
+qm7x:
+  default:
+    model: # SOTA settings
+      act: swish
+      att_heads: 1
+      complex_mp: true
+      cutoff: 5.0
+      edge_embed_type: all_rij
+      energy_head: false
+      force_decoder_model_config:
+        mlp:
+          hidden_channels: 256
+          norm: batch1d
+        res:
+          hidden_channels: 128
+          norm: batch1d
+        res_updown:
+          hidden_channels: 128
+          norm: layer
+        simple:
+          hidden_channels: 128
+          norm: batch1d
+      force_decoder_type: res_updown
+      graph_norm: false
+      hidden_channels: 500
+      max_num_neighbors: 40
+      mp_type: updownscale_base
+      num_filters: 400
+      num_gaussians: 50
+      num_interactions: 5
+      otf_graph: false
+      pg_hidden_channels: 32
+      phys_embeds: true
+      phys_hidden_channels: 0
+      regress_forces: direct_with_gradient_target
+      second_layer_MLP: true
+      skip_co: false
+      tag_hidden_channels: 0
+      use_pbc: false
+
+    optim:
+      batch_size: 100
+      energy_grad_coefficient: 5
+      eval_batch_size: 100
+      eval_every: 0.34
+      factor: 0.75
+      force_coefficient: 75
+      loss_energy: mae
+      loss_force: mse
+      lr_gamma: 0.1
+      lr_initial: 0.000193
+      max_steps: 4000000
+      min_lr: 1.0e-06
+      mode: min
+      optimizer: AdamW
+      scheduler: ReduceLROnPlateau
+      threshold: 0.001
+      threshold_mode: abs
+      verbose: true
+      warmup_factor: 0.2
+      warmup_steps: 3000
+
+  all: {}
+  1k: {}
+
+qm9:
+  default:
+    model:
+      use_pbc: False
+  all: {}
+  10k: {}
diff --git a/configs/models/painn.yaml b/configs/models/painn.yaml
index 2c0abac112..c138652a81 100644
--- a/configs/models/painn.yaml
+++ b/configs/models/painn.yaml
@@ -2,6 +2,9 @@ default:
   model:
     name: painn
     use_pbc: True
+  optim:
+    num_workers: 4
+    eval_batch_size: 64
 
 # -------------------
 # ----- IS2RE -----
diff --git a/mila/sbatch.py b/mila/sbatch.py
index ed8fa878d6..b6417adf1b 100644
--- a/mila/sbatch.py
+++ b/mila/sbatch.py
@@ -1,12 +1,13 @@
-from minydra import resolved_args, MinyDict
-from pathlib import Path
-from datetime import datetime
 import os
+import re
 import subprocess
-from shutil import copyfile
 import sys
-import re
+from datetime import datetime
+from pathlib import Path
+from shutil import copyfile
+
+import yaml
+from minydra import MinyDict, resolved_args
 
 IS_DRAC = (
     "narval.calcul.quebec" in os.environ.get("HOSTNAME", "")
diff --git a/ocpmodels/common/flags.py b/ocpmodels/common/flags.py
index a6fbf20d02..761e61dac2 100644
--- a/ocpmodels/common/flags.py
+++ b/ocpmodels/common/flags.py
@@ -87,12 +87,14 @@ def add_core_args(self):
             "--checkpoint", type=str, help="Model checkpoint to load"
         )
         self.parser.add_argument(
-            "--continue_from_dir", type=str, help="Run to continue, loading its config"
+            "--continue_from_dir",
+            type=str,
+            help="Continue an existing run, loading its config and overwriting desired arguments",
         )
         self.parser.add_argument(
             "--restart_from_dir",
             type=str,
-            help="Run to restart, loading its config and overwriting "
+            help="Restart training from an existing run, loading its config and overwriting args "
             + "from the command-line",
         )
         self.parser.add_argument(
@@ -293,6 +295,18 @@ def add_core_args(self):
             help="Number of validation loops to run in order to collect inference"
             + " timing stats",
         )
+        self.parser.add_argument(
+            "--is_disconnected",
+            type=bool,
+            default=False,
+            help="Eliminates edges between catalyst and adsorbate.",
+        )
+        self.parser.add_argument(
+            "--lowest_energy_only",
+            type=bool,
+            default=False,
+            help="Makes trainer use the lowest energy data point for every (catalyst, adsorbate, cell) tuple. ONLY USE WITH ALL DATASET",
+        )
 
 
 flags = Flags()
diff --git a/ocpmodels/datasets/data_transforms.py b/ocpmodels/datasets/data_transforms.py
index 6c26d2a9a9..17a63dfa52 100644
--- a/ocpmodels/datasets/data_transforms.py
+++ b/ocpmodels/datasets/data_transforms.py
@@ -127,6 +127,35 @@ def __call__(self, data):
         return self.rewiring_func(data)
 
 
+class Disconnected(Transform):
+    def __init__(self, is_disconnected=False) -> None:
+        self.inactive = not is_disconnected
+
+    def edge_classifier(self, edge_index, tags):
+        edges_with_tags = tags[
+            edge_index.type(torch.long)
+        ]  # Tensor with shape=edge_index.shape where every entry is a tag
+        filt1 = edges_with_tags[0] == edges_with_tags[1]
+        filt2 = (edges_with_tags[0] != 2) * (edges_with_tags[1] != 2)
+
+        # Edge is removed if tags are different (R1), and at least one end has tag 2 (R2). We want ~(R1*R2) = ~R1+~R2.
+        # filt1 = ~R1. Let L1 be that head has tag 2, and L2 is that tail has tag 2. Then R2 = L1+L2, so ~R2 = ~L1*~L2 = filt2.
+
+        return filt1 + filt2
+
+    def __call__(self, data):
+        if self.inactive:
+            return data
+
+        values = self.edge_classifier(data.edge_index, data.tags)
+
+        data.edge_index = data.edge_index[:, values]
+        data.cell_offsets = data.cell_offsets[values, :]
+        data.distances = data.distances[values]
+
+        return data
+
+
 class Compose:
     # https://pytorch.org/vision/stable/_modules/torchvision/transforms/transforms.html#Compose
     def __init__(self, transforms):
@@ -167,5 +196,6 @@ def get_transforms(trainer_config):
         AddAttributes(),
         GraphRewiring(trainer_config.get("graph_rewiring")),
         FrameAveraging(trainer_config["frame_averaging"], trainer_config["fa_method"]),
+        Disconnected(trainer_config["is_disconnected"]),
     ]
     return Compose(transforms)
diff --git a/ocpmodels/models/__init__.py b/ocpmodels/models/__init__.py
index a722f78170..c15c217b02 100644
--- a/ocpmodels/models/__init__.py
+++ b/ocpmodels/models/__init__.py
@@ -7,6 +7,7 @@
 from .cgcnn import CGCNN  # noqa: F401
 from .dimenet import DimeNet  # noqa: F401
 from .faenet import FAENet  # noqa: F401
+from .depfaenet import depFAENet  # noqa: F401
 from .gemnet.gemnet import GemNetT  # noqa: F401
 from .dimenet_plus_plus import DimeNetPlusPlus  # noqa: F401
 from .forcenet import ForceNet  # noqa: F401
diff --git a/ocpmodels/models/base_model.py b/ocpmodels/models/base_model.py
index 4a5c84a205..e2df0e7375 100644
--- a/ocpmodels/models/base_model.py
+++ b/ocpmodels/models/base_model.py
@@ -4,10 +4,12 @@
 This source code is licensed under the MIT license found in the
 LICENSE file in the root directory of this source tree.
 """
+
 import logging
 
 import torch
 import torch.nn as nn
+from torch_geometric.data import HeteroData
 from torch_geometric.nn import radius_graph
 
 from ocpmodels.common.utils import (
@@ -74,7 +76,14 @@ def forward(self, data, mode="train", regress_forces=None, q=None):
             # energy gradient w.r.t. positions will be computed
             if mode == "train" or self.regress_forces == "from_energy":
-                data.pos.requires_grad_(True)
+                if type(data) is list:
+                    data[0].pos.requires_grad_(True)
+                    data[1].pos.requires_grad_(True)
+                elif type(data[0]) is HeteroData:
+                    data["adsorbate"].pos.requires_grad_(True)
+                    data["catalyst"].pos.requires_grad_(True)
+                else:
+                    data.pos.requires_grad_(True)
 
         # predict energy
         preds = self.energy_forward(data, q=q)
@@ -85,7 +94,20 @@
             forces = self.forces_forward(preds)
 
         if mode == "train" or self.regress_forces == "from_energy":
-            grad_forces = self.forces_as_energy_grad(data.pos, preds["energy"])
+            if (
+                "gemnet" in self.__class__.__name__.lower()
+                and self.regress_forces == "from_energy"
+            ):
+                # gemnet forces are already computed
+                grad_forces = forces
+            else:
+                # compute forces from energy gradient
+                try:
+                    grad_forces = self.forces_as_energy_grad(
+                        data.pos, preds["energy"]
+                    )
+                except:
+                    grad_forces = self.forces_as_energy_grad(
+                        data["adsorbate"].pos, preds["energy"]
+                    )
 
         if self.regress_forces == "from_energy":
             # predicted forces are the energy gradient
diff --git a/ocpmodels/models/depfaenet.py b/ocpmodels/models/depfaenet.py
new file mode 100644
index 0000000000..25f6a09683
--- /dev/null
+++ b/ocpmodels/models/depfaenet.py
@@ -0,0 +1,97 @@
+import torch
+from torch.nn import Linear
+from torch import nn
+from torch_scatter import scatter
+
+from ocpmodels.models.faenet import FAENet
+from ocpmodels.models.faenet import OutputBlock as conOutputBlock
+from ocpmodels.common.registry import registry
+from ocpmodels.common.utils import conditional_grad
+from ocpmodels.models.utils.activations import swish
+
+from torch_geometric.data import Batch
+
+
+class discOutputBlock(conOutputBlock):
+    def __init__(self, energy_head, hidden_channels, act, disconnected_mlp=False):
+        super(discOutputBlock, self).__init__(energy_head, hidden_channels, act)
+
+        # We modify the last output linear function to make the output a vector
+        self.lin2 = Linear(hidden_channels // 2, hidden_channels // 2)
+
+        self.disconnected_mlp = disconnected_mlp
+        if self.disconnected_mlp:
+            self.ads_lin = Linear(hidden_channels // 2, hidden_channels // 2)
+            self.cat_lin = Linear(hidden_channels // 2, hidden_channels // 2)
+
+        # Combines the hidden representation of each to a scalar.
+        self.combination = nn.Sequential(
+            Linear(hidden_channels // 2 * 2, hidden_channels // 2),
+            swish,
+            Linear(hidden_channels // 2, 1),
+        )
+
+    def tags_saver(self, tags):
+        self.current_tags = tags
+
+    def forward(self, h, edge_index, edge_weight, batch, alpha):
+        if (
+            self.energy_head == "weighted-av-final-embeds"
+        ):  # Right now, this is the only available option.
+            alpha = self.w_lin(h)
+
+        elif self.energy_head == "graclus":
+            h, batch = self.graclus(h, edge_index, edge_weight, batch)
+
+        elif self.energy_head in {"pooling", "random"}:
+            h, batch, pooling_loss = self.hierarchical_pooling(
+                h, edge_index, edge_weight, batch
+            )
+
+        # MLP
+        h = self.lin1(h)
+        h = self.lin2(self.act(h))
+
+        if self.energy_head in {
+            "weighted-av-initial-embeds",
+            "weighted-av-final-embeds",
+        }:
+            h = h * alpha
+
+        # We pool separately and then we concatenate.
+        ads = self.current_tags == 2
+        cat = ~ads
+
+        ads_out = scatter(h, batch * ads, dim=0, reduce="add")
+        cat_out = scatter(h, batch * cat, dim=0, reduce="add")
+
+        if self.disconnected_mlp:
+            ads_out = self.ads_lin(ads_out)
+            cat_out = self.cat_lin(cat_out)
+
+        system = torch.cat([ads_out, cat_out], dim=1)
+
+        # Finally, we predict a number.
+        energy = self.combination(system)
+
+        return energy
+
+
+@registry.register_model("depfaenet")
+class depFAENet(FAENet):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+        # We replace the old output block by the new output block
+        self.disconnected_mlp = kwargs.get("disconnected_mlp", False)
+        self.output_block = discOutputBlock(
+            self.energy_head, kwargs["hidden_channels"], self.act, self.disconnected_mlp
+        )
+
+    @conditional_grad(torch.enable_grad())
+    def energy_forward(self, data):
+        # We need to save the tags so this step is necessary.
+        self.output_block.tags_saver(data.tags)
+        pred = super().energy_forward(data)
+
+        return pred
diff --git a/ocpmodels/preprocessing/graph_rewiring.py b/ocpmodels/preprocessing/graph_rewiring.py
index 2f3b103a6c..b9115e9077 100644
--- a/ocpmodels/preprocessing/graph_rewiring.py
+++ b/ocpmodels/preprocessing/graph_rewiring.py
@@ -36,6 +36,11 @@ def remove_tag0_nodes(data):
     data.tags = data.tags[non_sub]
     if hasattr(data, "pos_relaxed"):
         data.pos_relaxed = data.pos_relaxed[non_sub, :]
+    if hasattr(data, "query"):
+        data.h = data.h[non_sub, :]
+        data.query = data.query[non_sub, :]
+        data.key = data.key[non_sub, :]
+        data.value = data.value[non_sub, :]
 
     # per-edge tensors
     data.edge_index = data.edge_index[:, neither_is_sub]
diff --git a/ocpmodels/trainers/base_trainer.py b/ocpmodels/trainers/base_trainer.py
index ea1537737d..e871027efe 100644
--- a/ocpmodels/trainers/base_trainer.py
+++ b/ocpmodels/trainers/base_trainer.py
@@ -8,6 +8,7 @@
 import errno
 import logging
 import os
+import pickle
 import random
 import time
 from abc import ABC, abstractmethod
@@ -24,7 +25,7 @@
 from rich.console import Console
 from rich.table import Table
 from torch.nn.parallel.distributed import DistributedDataParallel
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, Subset
 from torch_geometric.data import Batch
 from tqdm import tqdm
 
@@ -57,6 +58,7 @@ class BaseTrainer(ABC):
 
     def __init__(self, load=True, **kwargs):
         run_dir = kwargs["run_dir"]
+
         model_name = kwargs["model"].pop(
             "name", kwargs.get("model_name", "Unknown - base_trainer issue")
         )
@@ -173,9 +175,21 @@ def __init__(self, load=True, **kwargs):
             )
             (run_dir / f"config-{JOB_ID}.yaml").write_text(yaml.dump(self.config))
 
-        if load:
-            self.load()
+        # Here's the models whose edges are removed as a transform
+        transform_models = [
+            "depfaenet",
+        ]
+        if self.config["is_disconnected"]:
+            print("\n\nHeads up: cat-ads edges being removed!")
+        if self.config["model_name"] in transform_models:
+            if not self.config["is_disconnected"]:
+                print(
+                    f"\n\nWhen using {self.config['model_name']},",
+                    "the flag 'is_disconnected' should be used! The flag has been turned on.\n",
+                )
+                self.config["is_disconnected"] = True
 
+        self.load()
         self.evaluator = Evaluator(
             task=self.task_name,
             model_regresses_forces=self.config["model"].get("regress_forces", ""),
@@ -244,6 +258,7 @@ def get_dataloader(self, dataset, sampler):
             pin_memory=True,
             batch_sampler=sampler,
         )
+
         return loader
 
     def load_datasets(self):
@@ -281,6 +296,16 @@ def load_datasets(self):
                 silent=self.silent,
             )
 
+            if self.config["lowest_energy_only"]:
+                with open(
+                    "/network/scratch/a/alvaro.carbonero/lowest_energy.pkl", "rb"
+                ) as fp:
+                    good_indices = pickle.load(fp)
+                good_indices = list(good_indices)
+
+                self.real_dataset = self.datasets["train"]
+                self.datasets["train"] = Subset(self.datasets["train"], good_indices)
+
             shuffle = False
             if "train" in split:
                 shuffle = True
@@ -402,6 +427,7 @@ def load_model(self):
                 "task_name": self.task_name,
             },
             **self.config["model"],
+            "model_name": self.config["model_name"],
         }
 
         self.model = registry.get_model_class(self.config["model_name"])(
diff --git a/ocpmodels/trainers/single_trainer.py b/ocpmodels/trainers/single_trainer.py
index c8850fe1af..25f82ec9a7 100644
--- a/ocpmodels/trainers/single_trainer.py
+++ b/ocpmodels/trainers/single_trainer.py
@@ -227,6 +227,8 @@ def train(
 
         # Calculate start_epoch from step instead of loading the epoch number
         # to prevent inconsistencies due to different batch size in checkpoint.
+        if self.config["continue_from_dir"] is not None and self.config["adsorbates"] not in {None, "all"}:
+            self.step = 0
         start_epoch = self.step // n_train
         max_epochs = self.config["optim"]["max_epochs"]
         timer = Times()
@@ -498,7 +500,11 @@ def end_of_training(
         # Close datasets
         if debug_batches < 0:
             for ds in self.datasets.values():
-                ds.close_db()
+                try:
+                    ds.close_db()
+                except:
+                    assert self.config["lowest_energy_only"] == True
+                    self.real_dataset.close_db()
 
     def model_forward(self, batch_list, mode="train", q=None):
         """Perform a forward pass of the model when frame averaging is applied.
diff --git a/scripts/debug_faenet.py b/scripts/debug_faenet.py
new file mode 100644
index 0000000000..56d79c3d68
--- /dev/null
+++ b/scripts/debug_faenet.py
@@ -0,0 +1,222 @@
+"""
+Copyright (c) Facebook, Inc. and its affiliates.
+
+This source code is licensed under the MIT license found in the
+LICENSE file in the root directory of this source tree.
+"""
+
+import logging
+import os
+import time
+import traceback
+import sys
+
+import torch
+from yaml import dump
+
+from ocpmodels.common import dist_utils
+from ocpmodels.common.flags import flags
+from ocpmodels.common.registry import registry
+from ocpmodels.common.utils import (
+    JOB_ID,
+    auto_note,
+    build_config,
+    merge_dicts,
+    move_lmdb_data_to_slurm_tmpdir,
+    resolve,
+    setup_imports,
+    setup_logging,
+    update_from_sbatch_py_vars,
+    set_min_hidden_channels,
+)
+from ocpmodels.common.orion_utils import (
+    continue_orion_exp,
+    load_orion_exp,
+    sample_orion_hparams,
+)
+from ocpmodels.trainers import BaseTrainer
+
+# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+torch.multiprocessing.set_sharing_strategy("file_system")
+
+
+def print_warnings():
+    warnings = [
+        "`max_num_neighbors` is set to 40. This should be tuned per model.",
+        "`tag_specific_weights` is not handled for "
+        + "`regress_forces: direct_with_gradient_target` in compute_loss()",
+    ]
+    print("\n" + "-" * 80 + "\n")
+    print("šŸ›‘ OCP-DR-Lab Warnings (nota benes):")
+    for warning in warnings:
+        print(f" • {warning}")
+    print("Remove warnings when they are fixed in the code/configs.")
+    print("\n" + "-" * 80 + "\n")
+
+
+def wrap_up(args, start_time, error=None, signal=None, trainer=None):
+    total_time = time.time() - start_time
+    logging.info(f"Total time taken: {total_time}")
+    if trainer and trainer.logger is not None:
+        trainer.logger.log({"Total time": total_time})
+
+    if args.distributed:
+        print(
+            "\nWaiting for all processes to finish with dist_utils.cleanup()...",
+            end="",
+        )
+        dist_utils.cleanup()
+        print("Done!")
+
+    if "interactive" not in os.popen(f"squeue -hj {JOB_ID}").read():
+        print("\nSelf-canceling SLURM job in 32s", JOB_ID)
+        os.popen(f"sleep 32 && scancel {JOB_ID}")
+
+    if trainer and trainer.logger:
+        trainer.logger.finish(error or signal)
+
+
+if __name__ == "__main__":
+    error = signal = orion_exp = orion_trial = trainer = None
+    orion_race_condition = False
+    hparams = {}
+
+    setup_logging()
+
+    parser = flags.get_parser()
+    args, override_args = parser.parse_known_args()
+    args = update_from_sbatch_py_vars(args)
+    if args.logdir:
+        args.logdir = resolve(args.logdir)
+
+    # -- Build config
+
+    args.wandb_name = "alvaro-carbonero-math"
+    args.wandb_project = "ocp-alvaro"
+    args.test_ri = True
+    args.mode = "train"
+    args.graph_rewiring = "remove-tag-0"
+    args.cp_data_to_tmpdir = True
+    args.config = "indfaenet-is2re-10k"
+    args.frame_averaging = "2D"
+    args.fa_frames = "se3-random"
+
+    trainer_config = build_config(args, override_args)
+
+    if dist_utils.is_master():
+        trainer_config = move_lmdb_data_to_slurm_tmpdir(trainer_config)
+    dist_utils.synchronize()
+
+    trainer_config["dataset"] = dist_utils.broadcast_from_master(
+        trainer_config["dataset"]
+    )
+
+    trainer_config["model"]["edge_embed_type"] = "all_rij"
+    trainer_config["model"]["mp_type"] = "updownscale"
+    trainer_config["model"]["phys_embeds"] = True
+    trainer_config["model"]["tag_hidden_channels"] = 32
+    trainer_config["model"]["pg_hidden_channels"] = 64
+    trainer_config["model"]["energy_head"] = "weighted-av-final-embeds"
+    trainer_config["model"]["complex_mp"] = False
+    trainer_config["model"]["graph_norm"] = True
+    trainer_config["model"]["hidden_channels"] = 352
+    trainer_config["model"]["num_filters"] = 448
+    trainer_config["model"]["num_gaussians"] = 99
+    trainer_config["model"]["num_interactions"] = 6
+    trainer_config["model"]["second_layer_MLP"] = True
+    trainer_config["model"]["skip_co"] = "concat"
+    # trainer_config["model"]["transformer_out"] = False
+    trainer_config["model"]["afaenet_gat_mode"] = "v1"
+    # trainer_config["model"]["disconnected_mlp"] = True
+
+    # trainer_config["optim"]["batch_sizes"] = 256
+    # trainer_config["optim"]["eval_batch_sizes"] = 256
+    trainer_config["optim"]["lr_initial"] = 0.0019
+    trainer_config["optim"]["scheduler"] = "LinearWarmupCosineAnnealingLR"
+    trainer_config["optim"]["max_epochs"] = 20
+    trainer_config["optim"]["eval_every"] = 0.4
+
+    # -- Initial setup
+
+    setup_imports()
+    print("\n🚩 All things imported.\n")
+    start_time = time.time()
+
+    try:
+        # -- Orion
+
+        if args.orion_exp_config_path and dist_utils.is_master():
+            orion_exp = load_orion_exp(args)
+            hparams, orion_trial = sample_orion_hparams(orion_exp, trainer_config)
+
+            if hparams.get("orion_race_condition"):
+                logging.warning("\n\n ā›”ļø Orion race condition. Stopping here.\n\n")
+                wrap_up(args, start_time, error, signal)
+                sys.exit()
+
+        hparams = dist_utils.broadcast_from_master(hparams)
+        if hparams:
+            print("\nšŸ’Ž Received hyper-parameters from Orion:")
+            print(dump(hparams), end="\n")
+            trainer_config = merge_dicts(trainer_config, hparams)
+
+        # -- Setup trainer
+        trainer_config = continue_orion_exp(trainer_config)
+        trainer_config = auto_note(trainer_config)
+        trainer_config = set_min_hidden_channels(trainer_config)
+
+        try:
+            cls = registry.get_trainer_class(trainer_config["trainer"])
+            trainer: BaseTrainer = cls(**trainer_config)
+        except Exception as e:
+            traceback.print_exc()
+            logging.warning(f"\nšŸ’€ Error in trainer initialization: {e}\n")
+            signal = "trainer_init_error"
+
+        if signal is None:
+            task = registry.get_task_class(trainer_config["mode"])(trainer_config)
+            task.setup(trainer)
+            print_warnings()
+
+            # -- Start Training
+
+            signal = task.run()
+
+        # -- End of training
+
+        # handle job preemption / time limit
+        if signal == "SIGTERM":
+            print("\nJob was preempted. Wrapping up...\n")
+            if trainer:
+                trainer.close_datasets()
+
+        dist_utils.synchronize()
+
+        objective = dist_utils.broadcast_from_master(
+            trainer.objective if trainer else None
+        )
+
+        if orion_exp is not None:
+            if objective is None:
+                if signal == "loss_is_nan":
+                    objective = 1e12
+                    print("Received NaN objective from worker. Setting to 1e12.")
+                if signal == "trainer_init_error":
+                    objective = 1e12
+                    print(
+                        "Received trainer_init_error from worker.",
+                        "Setting objective to 1e12.",
+                    )
+            if objective is not None:
+                orion_exp.observe(
+                    orion_trial,
+                    [{"type": "objective", "name": "energy_mae", "value": objective}],
+                )
+            else:
Skipping observation.") + + except Exception: + error = True + print(traceback.format_exc()) + + finally: + wrap_up(args, start_time, error, signal, trainer=trainer) From 442ca59bcd788167e7c0533dde0a8b7c0e9a1770 Mon Sep 17 00:00:00 2001 From: AlexDuvalinho Date: Fri, 19 Apr 2024 06:54:22 -0400 Subject: [PATCH 07/45] remove edge_embed_type --- configs/exps/catalyst/gflownet.yaml | 76 ++++++++++++------------ configs/exps/is2re/top-configs.yaml | 2 - configs/exps/orion/faenet-is2re-all.yaml | 2 - configs/exps/orion/faenet-qm9.yaml | 2 - configs/models/depfaenet.yaml | 3 - configs/models/deup_faenet.yaml | 3 - scripts/debug_faenet.py | 1 - scripts/test_all.py | 6 +- 8 files changed, 40 insertions(+), 55 deletions(-) diff --git a/configs/exps/catalyst/gflownet.yaml b/configs/exps/catalyst/gflownet.yaml index 2432f47339..4499b6e2a6 100644 --- a/configs/exps/catalyst/gflownet.yaml +++ b/configs/exps/catalyst/gflownet.yaml @@ -6,9 +6,8 @@ job: time: 15:00:00 default: - # wandb_name: alvaro-carbonero-math - wandb_project: ocp-alvaro - wandb_tags: "gflownet-model" + wandb_project: ocp-deup # ocp-alvaro + wandb_tags: gflownet-model, depfaenet test_ri: True mode: train # graph_rewiring: remove-tag-0 @@ -18,7 +17,6 @@ default: cp_data_to_tmpdir: True is_disconnected: true model: - edge_embed_type: all_rij mp_type: updownscale_base phys_embeds: True tag_hidden_channels: 0 @@ -89,55 +87,55 @@ runs: # max_epochs: 10 # lr_initial: 0.00015 - - config: depfaenet-is2re-all - note: Fine-tune on per-ads-dataset 10 epoch - continue_from_dir: /network/scratch/a/alexandre.duval/ocp/runs/4023244 - adsorbates: "*O, *OH, *OH2, *H" - optim: - max_epochs: 20 - lr_initial: 0.0001 + # - config: depfaenet-is2re-all + # note: Fine-tune on per-ads-dataset 10 epoch + # continue_from_dir: /network/scratch/a/alexandre.duval/ocp/runs/4023244 + # adsorbates: "*O, *OH, *OH2, *H" + # optim: + # max_epochs: 20 + # lr_initial: 0.0001 - - config: depfaenet-is2re-all - note: Fine-tune on per-ads-dataset 20 epoch - continue_from_dir: /network/scratch/a/alexandre.duval/ocp/runs/4023244 - adsorbates: "*O, *OH, *OH2, *H" - optim: - max_epochs: 20 - lr_initial: 0.00015 + # - config: depfaenet-is2re-all + # note: Fine-tune on per-ads-dataset 20 epoch + # continue_from_dir: /network/scratch/a/alexandre.duval/ocp/runs/4023244 + # adsorbates: "*O, *OH, *OH2, *H" + # optim: + # max_epochs: 20 + # lr_initial: 0.00015 - config: depfaenet-is2re-all - note: Fine-tune on per-ads-dataset 15 epoch + note: Depfaenet per-ads-dataset continue_from_dir: /network/scratch/a/alexandre.duval/ocp/runs/4023244 adsorbates: "*O, *OH, *OH2, *H" optim: - max_epochs: 15 + max_epochs: 12 lr_initial: 0.0002 - config: depfaenet-is2re-all - note: Fine-tune on per-ads-dataset 10 epoch + note: Depfaenet per-ads-dataset continue_from_dir: /network/scratch/a/alexandre.duval/ocp/runs/4023244 adsorbates: "*O, *OH, *OH2, *H" optim: max_epochs: 10 lr_initial: 0.0001 - - config: depfaenet-is2re-all - note: Fine-tune on per-ads-dataset starting from fine-tuned model - continue_from_dir: /network/scratch/a/alexandre.duval/ocp/runs/4071859 - adsorbates: "*O, *OH, *OH2, *H" - optim: - max_epochs: 10 - lr_initial: 0.0001 + # - config: depfaenet-is2re-all + # note: Fine-tune on per-ads-dataset starting from fine-tuned model + # continue_from_dir: /network/scratch/a/alexandre.duval/ocp/runs/4071859 + # adsorbates: "*O, *OH, *OH2, *H" + # optim: + # max_epochs: 10 + # lr_initial: 0.0001 - - config: depfaenet-is2re-all - note: Trained on selected adsorbate - adsorbates: "*O, *OH, *OH2, 
*H" - optim: - max_epochs: 25 - lr_initial: 0.0001 + # - config: depfaenet-is2re-all + # note: Trained on selected adsorbate + # adsorbates: "*O, *OH, *OH2, *H" + # optim: + # max_epochs: 25 + # lr_initial: 0.0001 - - config: depfaenet-is2re-all - note: Trained on selected adsorbate - adsorbates: "*O, *OH, *OH2, *H" - optim: - max_epochs: 25 + # - config: depfaenet-is2re-all + # note: Trained on selected adsorbate + # adsorbates: "*O, *OH, *OH2, *H" + # optim: + # max_epochs: 25 diff --git a/configs/exps/is2re/top-configs.yaml b/configs/exps/is2re/top-configs.yaml index cf4e79fe4c..6fa8826480 100644 --- a/configs/exps/is2re/top-configs.yaml +++ b/configs/exps/is2re/top-configs.yaml @@ -9,8 +9,6 @@ default: test_ri: True mode: train graph_rewiring: remove-tag-0 - model: - edge_embed_type: all_rij wandb_tags: "best-config" optim: batch_size: 256 diff --git a/configs/exps/orion/faenet-is2re-all.yaml b/configs/exps/orion/faenet-is2re-all.yaml index b3a1ccbca4..baecd59d92 100644 --- a/configs/exps/orion/faenet-is2re-all.yaml +++ b/configs/exps/orion/faenet-is2re-all.yaml @@ -14,8 +14,6 @@ default: wandb_tags: is2re-all, orion cp_data_to_tmpdir: true graph_rewiring: remove-tag-0 - model: - edge_embed_type: all_rij frame_averaging: 2D fa_method: random optim: diff --git a/configs/exps/orion/faenet-qm9.yaml b/configs/exps/orion/faenet-qm9.yaml index 2d26414fd2..722ed4472b 100644 --- a/configs/exps/orion/faenet-qm9.yaml +++ b/configs/exps/orion/faenet-qm9.yaml @@ -39,8 +39,6 @@ default: targets: hidden_channels, num_filters, pg_hidden_channels, phys_hidden_channels, batch_size frame_averaging: 3D fa_method: random - model: - edge_embed_type: all_rij orion: # Remember to change the experiment name if you change anything in the search space diff --git a/configs/models/depfaenet.yaml b/configs/models/depfaenet.yaml index 852ebc3bfd..da19d6e350 100644 --- a/configs/models/depfaenet.yaml +++ b/configs/models/depfaenet.yaml @@ -19,7 +19,6 @@ default: skip_co: False # output skip connections {False, "add", "concat"} second_layer_MLP: False # in EmbeddingBlock complex_mp: False - edge_embed_type: rij # {'rij','all_rij','sh', 'all'}) mp_type: base # {'base', 'simple', 'updownscale', 'att', 'base_with_att', 'local_env'} graph_norm: False # bool att_heads: 1 # int @@ -152,7 +151,6 @@ qm9: att_heads: 1 complex_mp: true cutoff: 6.0 - edge_embed_type: all_rij energy_head: '' graph_norm: true graph_rewiring: null @@ -205,7 +203,6 @@ qm7x: att_heads: 1 complex_mp: true cutoff: 5.0 - edge_embed_type: all_rij energy_head: false force_decoder_model_config: mlp: diff --git a/configs/models/deup_faenet.yaml b/configs/models/deup_faenet.yaml index efa779c801..bdc723bb55 100644 --- a/configs/models/deup_faenet.yaml +++ b/configs/models/deup_faenet.yaml @@ -25,7 +25,6 @@ default: skip_co: False # output skip connections {False, "add", "concat"} second_layer_MLP: False # in EmbeddingBlock complex_mp: False - edge_embed_type: rij # {'rij','all_rij','sh', 'all'}) mp_type: base # {'base', 'simple', 'updownscale', 'att', 'base_with_att', 'local_env'} graph_norm: False # bool att_heads: 1 # int @@ -153,7 +152,6 @@ qm9: att_heads: 1 complex_mp: true cutoff: 6.0 - edge_embed_type: all_rij energy_head: '' graph_norm: true graph_rewiring: null @@ -205,7 +203,6 @@ qm7x: att_heads: 1 complex_mp: true cutoff: 5.0 - edge_embed_type: all_rij energy_head: false force_decoder_model_config: mlp: diff --git a/scripts/debug_faenet.py b/scripts/debug_faenet.py index 56d79c3d68..6e55aef82b 100644 --- a/scripts/debug_faenet.py +++ 
+++ b/scripts/debug_faenet.py
@@ -110,7 +110,6 @@ def wrap_up(args, start_time, error=None, signal=None, trainer=None):
         trainer_config["dataset"]
     )
 
-    trainer_config["model"]["edge_embed_type"] = "all_rij"
     trainer_config["model"]["mp_type"] = "updownscale"
     trainer_config["model"]["phys_embeds"] = True
     trainer_config["model"]["tag_hidden_channels"] = 32
diff --git a/scripts/test_all.py b/scripts/test_all.py
index 783f6f302a..39d69b4a26 100644
--- a/scripts/test_all.py
+++ b/scripts/test_all.py
@@ -180,9 +180,9 @@ def isin(key, args):
     "--config=sfarinet-qm7x-1k --regress_forces=direct",
     "--config=sfarinet-qm7x-1k --regress_forces=direct_with_gradient_target",
     "--config=sfarinet-qm7x-1k --regress_forces=from_energy",
-    "--config=faenet-is2re-10k --model.edge_embed_type=rij --model.mp_type=base",
-    "--config=faenet-is2re-10k --model.edge_embed_type=all --model.mp_type=simple",
-    "--config=faenet-is2re-10k --model.edge_embed_type=sh --model.mp_type=updownscale",
+    "--config=faenet-is2re-10k --model.mp_type=base",
+    "--config=faenet-is2re-10k --model.mp_type=simple",
+    "--config=faenet-is2re-10k --model.mp_type=updownscale",
     # "--config=faenet-is2re-10k --model.edge_embed_type=all_rij --model.mp_type=local_env",
     # "--config=faenet-is2re-10k --model.mp_type=att",
     # "--config=faenet-is2re-10k --model.mp_type=base_with_att",

From 0d70e8e9488b32de955d014f7ad89118030f1c83 Mon Sep 17 00:00:00 2001
From: AlexDuvalinho
Date: Fri, 19 Apr 2024 07:23:36 -0400
Subject: [PATCH 08/45] create deup-depfaenet, add dropout_lin, modify class
 names

---
 configs/exps/catalyst/gflownet.yaml |   2 +-
 ocpmodels/models/__init__.py        |   2 +-
 ocpmodels/models/depfaenet.py       |  30 ++++----
 ocpmodels/models/deup_depfaenet.py  | 105 ++++++++++++++++++++++++++++
 4 files changed, 122 insertions(+), 17 deletions(-)
 create mode 100644 ocpmodels/models/deup_depfaenet.py

diff --git a/configs/exps/catalyst/gflownet.yaml b/configs/exps/catalyst/gflownet.yaml
index 4499b6e2a6..8dc46c189c 100644
--- a/configs/exps/catalyst/gflownet.yaml
+++ b/configs/exps/catalyst/gflownet.yaml
@@ -1,7 +1,7 @@
 job:
   mem: 32GB
   cpus: 4
-  gres: gpu:rtx8000:1
+  gres: gpu:1
   partition: long
   time: 15:00:00
 
diff --git a/ocpmodels/models/__init__.py b/ocpmodels/models/__init__.py
index c15c217b02..9241e161f7 100644
--- a/ocpmodels/models/__init__.py
+++ b/ocpmodels/models/__init__.py
@@ -7,7 +7,7 @@
 from .cgcnn import CGCNN  # noqa: F401
 from .dimenet import DimeNet  # noqa: F401
 from .faenet import FAENet  # noqa: F401
-from .depfaenet import depFAENet  # noqa: F401
+from .depfaenet import DepFAENet  # noqa: F401
 from .gemnet.gemnet import GemNetT  # noqa: F401
 from .dimenet_plus_plus import DimeNetPlusPlus  # noqa: F401
 from .forcenet import ForceNet  # noqa: F401
diff --git a/ocpmodels/models/depfaenet.py b/ocpmodels/models/depfaenet.py
index 25f6a09683..97d1979163 100644
--- a/ocpmodels/models/depfaenet.py
+++ b/ocpmodels/models/depfaenet.py
@@ -2,6 +2,7 @@
 from torch.nn import Linear
 from torch import nn
 from torch_scatter import scatter
+import torch.nn.functional as F
 
 from ocpmodels.models.faenet import FAENet
 from ocpmodels.models.faenet import OutputBlock as conOutputBlock
@@ -12,9 +13,9 @@
 from torch_geometric.data import Batch
 
 
-class discOutputBlock(conOutputBlock):
-    def __init__(self, energy_head, hidden_channels, act, disconnected_mlp=False):
-        super(discOutputBlock, self).__init__(energy_head, hidden_channels, act)
+class DiscOutputBlock(conOutputBlock):
+    def __init__(self, energy_head, hidden_channels, act, dropout_lin, disconnected_mlp=False):
+        super(DiscOutputBlock, self).__init__(energy_head, hidden_channels, act, dropout_lin)
 
         # We modify the last output linear function to make the output a vector
         self.lin2 = Linear(hidden_channels // 2, hidden_channels // 2)
@@ -40,17 +41,16 @@ def forward(self, h, edge_index, edge_weight, batch, alpha):
         ):  # Right now, this is the only available option.
             alpha = self.w_lin(h)
 
-        elif self.energy_head == "graclus":
-            h, batch = self.graclus(h, edge_index, edge_weight, batch)
-
-        elif self.energy_head in {"pooling", "random"}:
-            h, batch, pooling_loss = self.hierarchical_pooling(
-                h, edge_index, edge_weight, batch
-            )
-
         # MLP
+        h = F.dropout(
+            h, p=self.dropout_lin, training=self.training or self.deup_inference
+        )
         h = self.lin1(h)
-        h = self.lin2(self.act(h))
+        h = self.act(h)
+        h = F.dropout(
+            h, p=self.dropout_lin, training=self.training or self.deup_inference
+        )
+        h = self.lin2(h)
 
         if self.energy_head in {
             "weighted-av-initial-embeds",
@@ -78,14 +78,14 @@ def forward(self, h, edge_index, edge_weight, batch, alpha):
 
 
 @registry.register_model("depfaenet")
-class depFAENet(FAENet):
+class DepFAENet(FAENet):
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
 
         # We replace the old output block by the new output block
         self.disconnected_mlp = kwargs.get("disconnected_mlp", False)
-        self.output_block = discOutputBlock(
-            self.energy_head, kwargs["hidden_channels"], self.act, self.disconnected_mlp
+        self.output_block = DiscOutputBlock(
+            self.energy_head, kwargs["hidden_channels"], self.act, self.disconnected_mlp, self.dropout_lin,
         )
 
     @conditional_grad(torch.enable_grad())
diff --git a/ocpmodels/models/deup_depfaenet.py b/ocpmodels/models/deup_depfaenet.py
new file mode 100644
index 0000000000..8457acf45d
--- /dev/null
+++ b/ocpmodels/models/deup_depfaenet.py
@@ -0,0 +1,105 @@
+import torch
+from torch import nn
+from torch.nn import Linear
+from torch_scatter import scatter
+from ocpmodels.common.registry import registry
+from ocpmodels.models.depfaenet import DepFAENet, DiscOutputBlock
+
+
+class DeupDepOutputBlock(DiscOutputBlock):
+    def __init__(
+        self, energy_head, hidden_channels, act, dropout_lin, deup_features={}
+    ):
+        super().__init__(energy_head, hidden_channels, act, dropout_lin)
+
+        self.deup_features = deup_features
+        self.deup_data_keys = [f"deup_{k}" for k in deup_features]
+        self.deup_extra_dim = 0
+        self._set_q_dim = False
+
+        if "s" in deup_features:
+            self.deup_extra_dim += 1
+        if "energy_pred_std" in deup_features:
+            self.deup_extra_dim += 1
+        if "q" in deup_features:
+            self._set_q_dim = True
+
+        if self.deup_extra_dim > 0:
+            self.deup_lin = Linear(
+                self.lin1.out_features + self.deup_extra_dim, self.lin1.out_features
+            )
+
+    def forward(self, h, edge_index, edge_weight, batch, alpha, data=None):
+        if self._set_q_dim:
+            assert data is not None
+            assert "deup_q" in data.to_dict().keys()
+            self.deup_extra_dim += data.deup_q.shape[-1]
+            self.deup_lin = Linear(
+                self.lin1.out_features + self.deup_extra_dim, self.lin1.out_features
+            )
New dim:", self.deup_extra_dim) + print("āš ļø OutputBlock will be reinitialized.\n") + self.reset_parameters() + self._set_q_dim = False + + if self.energy_head == "weighted-av-final-embeds": + alpha = self.w_lin(h) + + # OutputBlock to get final atom rep + # No dropout in deup-(dep)faenet + h = self.lin1(h) + h = self.act(h) + if self.deup_extra_dim <= 0: + h = self.lin2(h) + + if self.energy_head in { + "weighted-av-initial-embeds", + "weighted-av-final-embeds", + }: + h = h * alpha + + # Global pooling -- get final graph rep + out = scatter( + h, + batch, + dim=0, + reduce="mean" if self.deup_extra_dim > 0 else "add", + ) + + # Concat graph representation with deup features (s, kde(q), std) + # and apply MLPs + if self.deup_extra_dim > 0: + assert data is not None + data_keys = set(data.to_dict().keys()) + assert all(dk in data_keys for dk in self.deup_data_keys), ( + f"Some deup data keys ({self.deup_data_keys}) are missing" + + f" from the data dict ({data_keys})" + ) + out = torch.cat( + [out] + + [data[f"deup_{k}"][:, None].float() for k in self.deup_features], + dim=-1, + ) + out = self.deup_lin(out) + out = self.act(out) + out = self.lin2(out) + + return out + +@registry.register_model("deup_depfaenet") +class DeupFAENet(DepFAENet): + def __init__(self, *args, **kwargs): + kwargs["dropout_edge"] = 0 + super().__init__(*args, **kwargs) + self.output_block = DeupDepOutputBlock( + self.energy_head, + kwargs["hidden_channels"], + self.act, + self.dropout_lin, + kwargs.get("deup_features", {}), + ) + assert ( + self.energy_head != "weighted-av-initial-embeds" + ), "Unsupported head weighted-av-initial-embeds" + assert self.skip_co != "concat", "Unsupported skip connection concat" + assert self.skip_co != "add", "Unsupported skip connection add" \ No newline at end of file From fd9d1d1524a6d661ca2a1f882f3c315f3125640d Mon Sep 17 00:00:00 2001 From: AlexDuvalinho Date: Fri, 19 Apr 2024 07:31:42 -0400 Subject: [PATCH 09/45] add q --- ocpmodels/models/depfaenet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ocpmodels/models/depfaenet.py b/ocpmodels/models/depfaenet.py index 97d1979163..af3da682d0 100644 --- a/ocpmodels/models/depfaenet.py +++ b/ocpmodels/models/depfaenet.py @@ -89,7 +89,7 @@ def __init__(self, **kwargs): ) @conditional_grad(torch.enable_grad()) - def energy_forward(self, data): + def energy_forward(self, data, q=None): # We need to save the tags so this step is necessary. self.output_block.tags_saver(data.tags) pred = super().energy_forward(data) From 2ab5c335b22903e958efb97f825a6f009d65ef28 Mon Sep 17 00:00:00 2001 From: AlexDuvalinho Date: Fri, 19 Apr 2024 08:00:59 -0400 Subject: [PATCH 10/45] fix forward of output block depfaenet --- ocpmodels/models/depfaenet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ocpmodels/models/depfaenet.py b/ocpmodels/models/depfaenet.py index af3da682d0..87f76e08c1 100644 --- a/ocpmodels/models/depfaenet.py +++ b/ocpmodels/models/depfaenet.py @@ -35,7 +35,7 @@ def __init__(self, energy_head, hidden_channels, act, dropout_lin, disconnected_ def tags_saver(self, tags): self.current_tags = tags - def forward(self, h, edge_index, edge_weight, batch, alpha): + def forward(self, h, edge_index, edge_weight, batch, alpha, data): if ( self.energy_head == "weighted-av-final-embeds" ): # Right now, this is the only available option. 
@@ -85,7 +85,7 @@ def __init__(self, **kwargs): # We replace the old output block by the new output block self.disconnected_mlp = kwargs.get("disconnected_mlp", False) self.output_block = DiscOutputBlock( - self.energy_head, kwargs["hidden_channels"], self.act, self.disconnected_mlp, self.dropout_lin, + self.energy_head, kwargs["hidden_channels"], self.act, self.dropout_lin, self.disconnected_mlp, ) @conditional_grad(torch.enable_grad()) From 9f18bfd2244e06a7a4e1c9eb3927223c2ed3df6e Mon Sep 17 00:00:00 2001 From: AlexDuvalinho Date: Tue, 23 Apr 2024 07:41:59 -0400 Subject: [PATCH 11/45] new model checkpoints to create deup-dataset --- ...c-faenet.yaml => data-with-depfaenet.yaml} | 6 ++-- configs/exps/deup/datasets/mc-faenet.yaml | 28 +++++++++++++++++++ ...c-faenet.yaml => old-train-mc-faenet.yaml} | 0 3 files changed, 31 insertions(+), 3 deletions(-) rename configs/exps/deup/datasets/{new-mc-faenet.yaml => data-with-depfaenet.yaml} (80%) create mode 100644 configs/exps/deup/datasets/mc-faenet.yaml rename configs/exps/deup/datasets/{train-mc-faenet.yaml => old-train-mc-faenet.yaml} (100%) diff --git a/configs/exps/deup/datasets/new-mc-faenet.yaml b/configs/exps/deup/datasets/data-with-depfaenet.yaml similarity index 80% rename from configs/exps/deup/datasets/new-mc-faenet.yaml rename to configs/exps/deup/datasets/data-with-depfaenet.yaml index 56ea298687..8c7d4a00e6 100644 --- a/configs/exps/deup/datasets/new-mc-faenet.yaml +++ b/configs/exps/deup/datasets/data-with-depfaenet.yaml @@ -7,11 +7,11 @@ job: default: config: faenet-is2re-all wandb_project: ocp-deup - wandb_tags: base-model, MC-D, 4615191 + wandb_tags: depfaenet, MC-D,4621042 test_ri: True mode: train - checkpoint: /network/scratch/a/alexandre.duval/scratch/ocp/runs/4615191/checkpoints/best_checkpoint.pt - restart_from_dir: /network/scratch/a/alexandre.duval/ocp/runs/4615191/ + checkpoint: /network/scratch/a/alexandre.duval/ocp/runs/4621042/checkpoints/best_checkpoint.pt + restart_from_dir: /network/scratch/a/alexandre.duval/ocp/runs/4621042/ model: dropout_lowest_layer: output first_trainable_layer: dropout diff --git a/configs/exps/deup/datasets/mc-faenet.yaml b/configs/exps/deup/datasets/mc-faenet.yaml new file mode 100644 index 0000000000..8069e3573e --- /dev/null +++ b/configs/exps/deup/datasets/mc-faenet.yaml @@ -0,0 +1,28 @@ +job: + mem: 32GB + cpus: 4 + gres: gpu:1 + partition: long + +default: + config: faenet-is2re-all + wandb_project: ocp-deup + wandb_tags: base-model, MC-D, 4616500 + test_ri: True + mode: train + checkpoint: /network/scratch/a/alexandre.duval/ocp/runs/4616500/checkpoints/best_checkpoint.pt + restart_from_dir: /network/scratch/a/alexandre.duval/ocp/runs/4616500/ + model: + dropout_lowest_layer: output + first_trainable_layer: dropout + dropout_lin: 0.7 + cp_data_to_tmpdir: true + inference_time_loops: 1 + deup_dataset: + create: after # "before" -> created before training (for deup) "after" -> created after training (for is2re) "" - not created + dataset_strs: ["train", "val_id", "val_ood_cat", "val_ood_ads"] + n_samples: 7 + +runs: + - optim: + max_epochs: 12 diff --git a/configs/exps/deup/datasets/train-mc-faenet.yaml b/configs/exps/deup/datasets/old-train-mc-faenet.yaml similarity index 100% rename from configs/exps/deup/datasets/train-mc-faenet.yaml rename to configs/exps/deup/datasets/old-train-mc-faenet.yaml From e0fb6f7738c746e205921ba9916166ea4f18a519 Mon Sep 17 00:00:00 2001 From: AlexDuvalinho Date: Tue, 23 Apr 2024 08:34:27 -0400 Subject: [PATCH 12/45] argparse deup_dataset + comments --- 
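Note: the diff below replaces the hard-coded `base_trainer_path` with
command-line arguments. A usage sketch (the run directories are placeholders,
not real paths):

    python -m ocpmodels.datasets.deup_dataset_creator \
        --checkpoints /path/to/run_a /path/to/run_b \
        --dropout 0.3

`--checkpoints` accepts one or more run directories (`nargs="+"`) and
`--dropout` sets the MC-Dropout rate used when building the ensemble.
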
ocpmodels/datasets/deup_dataset_creator.py | 29 ++++++++++++++++++---- scripts/deup_dataset.sh | 11 ++++++++ 2 files changed, 35 insertions(+), 5 deletions(-) create mode 100644 scripts/deup_dataset.sh diff --git a/ocpmodels/datasets/deup_dataset_creator.py b/ocpmodels/datasets/deup_dataset_creator.py index b57522422e..f3a4e3adc8 100644 --- a/ocpmodels/datasets/deup_dataset_creator.py +++ b/ocpmodels/datasets/deup_dataset_creator.py @@ -228,6 +228,7 @@ def _structure(preds): if self.mc_dropout: if n_samples <= 0: raise ValueError("n_samples must be > 0 for MC-Dropout ensembles.") + # Speed up computation by re-using latent representation q for all models preds += [ self.trainers[0].model_forward(batch_list, mode="deup", q=q) for _ in range(n_samples - len(preds)) @@ -320,12 +321,14 @@ def create_deup_dataset( preds = self.forward( batch_list, n_samples=n_samples, shared_encoder=True ) - + # Compute mean and standard deviation of GNN predictions pred_mean = preds["energies"].mean(dim=1) # Batch pred_std = preds["energies"].std(dim=1) # Batch + # Compute residual between mean predicted energy and ground truth loss = self.trainers[0].loss_fn["energy"]( pred_mean, batch.y_relaxed.to(pred_mean.device) ) + # Store deup samples deup_samples += [ { "energy_target": batch.y_relaxed.clone(), @@ -431,13 +434,29 @@ def write_lmdb(self, samples, path, total_size=-1, max_samples=-1): from ocpmodels.datasets.deup_dataset_creator import DeupDatasetCreator from ocpmodels.datasets.lmdb_dataset import DeupDataset from ocpmodels.common.utils import JOB_ID, RUNS_DIR, make_config_from_conf_str + import argparse + + def parse_args(): + parser = argparse.ArgumentParser(description="Deup Dataset Creator") + parser.add_argument( + "--checkpoints", + nargs="+", + default="/network/scratch/a/alexandre.duval/ocp/runs/4616500/", + help="Paths to the checkpoints", + ) + parser.add_argument( + "--dropout", + type=float, + default=0.2, + help="Dropout value", + ) + return parser.parse_args() - base_trainer_path = "/network/scratch/a/alexandre.duval/ocp/runs/4615191" + args = parse_args() - # what models to load for inference trainers_conf = { - "checkpoints": [base_trainer_path], - "dropout": 0.7, + "checkpoints": args.checkpoints, + "dropout": args.dropout, } # setting first_trainable_layer to output means that the latent space # q will be defined as input to the output layer, even though the model diff --git a/scripts/deup_dataset.sh b/scripts/deup_dataset.sh new file mode 100644 index 0000000000..d42384a058 --- /dev/null +++ b/scripts/deup_dataset.sh @@ -0,0 +1,11 @@ +#!/bin/bash +#SBATCH --job-name=deup-dataset +#SBATCH --ntasks=1 +#SBATCH --mem=32GB +#SBATCH --gres=gpu:1 +#SBATCH --output="/network/scratch/a/alexandre.duval/ocp/runs/output-%j.txt" # replace: location where you want to store the output of the job + +module load anaconda/3 # replace: load anaconda module +conda activate ocp # replace: conda env name +cd /home/mila/a/alexandre.duval/ocp/ocp # replace: location of the code +python -m ocpmodels.datasets.deup_dataset_creator \ No newline at end of file From 5b9c76f1973df00ae333333e71eadfc9f3af2053 Mon Sep 17 00:00:00 2001 From: AlexDuvalinho Date: Wed, 24 Apr 2024 05:18:45 -0400 Subject: [PATCH 13/45] fix chkpt_path + edge case error + new configs --- .../deup/datasets/data-with-depfaenet.yaml | 4 +- configs/exps/deup/gnn/depfaenet-training.yaml | 57 ++++++++++++++ configs/exps/deup/gnn/faenet-training.yaml | 7 +- configs/exps/deup/gnn/pretrain-depfaenet.yaml | 78 +++++++++++++++++++ 
ocpmodels/tasks/task.py | 8 +- ocpmodels/trainers/single_trainer.py | 34 +++++--- 6 files changed, 168 insertions(+), 20 deletions(-) create mode 100644 configs/exps/deup/gnn/pretrain-depfaenet.yaml diff --git a/configs/exps/deup/datasets/data-with-depfaenet.yaml b/configs/exps/deup/datasets/data-with-depfaenet.yaml index 8c7d4a00e6..e329beff87 100644 --- a/configs/exps/deup/datasets/data-with-depfaenet.yaml +++ b/configs/exps/deup/datasets/data-with-depfaenet.yaml @@ -5,7 +5,7 @@ job: partition: long default: - config: faenet-is2re-all + config: depfaenet-is2re-all wandb_project: ocp-deup wandb_tags: depfaenet, MC-D,4621042 test_ri: True @@ -15,7 +15,7 @@ default: model: dropout_lowest_layer: output first_trainable_layer: dropout - dropout_lin: 0.7 + dropout_lin: 0.3 cp_data_to_tmpdir: true inference_time_loops: 1 deup_dataset: diff --git a/configs/exps/deup/gnn/depfaenet-training.yaml b/configs/exps/deup/gnn/depfaenet-training.yaml index e69de29bb2..d81ac5d384 100644 --- a/configs/exps/deup/gnn/depfaenet-training.yaml +++ b/configs/exps/deup/gnn/depfaenet-training.yaml @@ -0,0 +1,57 @@ +job: + mem: 32GB + cpus: 4 + gres: gpu:1 + partition: long + time: 15:00:00 + +default: + wandb_project: ocp-deup + wandb_tags: depfaenet, no-concat, with-tag0, dropout + test_ri: True + mode: train + graph_rewiring: "" + frame_averaging: 2D + fa_method: se3-random + cp_data_to_tmpdir: True + is_disconnected: true + model: + mp_type: updownscale_base + phys_embeds: True + tag_hidden_channels: 0 + pg_hidden_channels: 96 + energy_head: weighted-av-final-embeds + complex_mp: True + graph_norm: True + hidden_channels: 352 + num_filters: 288 + num_gaussians: 68 + num_interactions: 5 + second_layer_MLP: False + skip_co: False + cutoff: 4.0 + dropout_lin: 0.3 + optim: + batch_size: 256 + eval_batch_size: 256 + lr_initial: 0.002 + scheduler: LinearWarmupCosineAnnealingLR + eval_every: 0.4 + +runs: + + - config: depfaenet-is2re-all + note: Depfaenet per-ads-dataset + continue_from_dir: /network/scratch/a/alexandre.duval/ocp/runs/4023244 + adsorbates: "*O, *OH, *OH2, *H" + optim: + max_epochs: 10 + lr_initial: 0.0002 + + - config: depfaenet-is2re-all + note: Depfaenet per-ads-dataset + continue_from_dir: /network/scratch/a/alexandre.duval/ocp/runs/4023244 + adsorbates: "*O, *OH, *OH2, *H" + optim: + max_epochs: 12 + lr_initial: 0.0001 diff --git a/configs/exps/deup/gnn/faenet-training.yaml b/configs/exps/deup/gnn/faenet-training.yaml index 8bf38ec5f2..0d6aa34d51 100644 --- a/configs/exps/deup/gnn/faenet-training.yaml +++ b/configs/exps/deup/gnn/faenet-training.yaml @@ -8,8 +8,7 @@ job: default: test_ri: True mode: train - graph_rewiring: remove-tag-0 - wandb_tags: "top-model" + wandb_tags: faenet, no-concat, with-tag0, dropout wandb_project: ocp-deup optim: batch_size: 256 @@ -36,8 +35,10 @@ runs: second_layer_MLP: False skip_co: False cutoff: 6.0 + dropout_lin: 0.3 + dropout_lowest_layer: output optim: lr_initial: 0.002 scheduler: LinearWarmupCosineAnnealingLR - max_epochs: 12 + max_epochs: 14 eval_every: 0.25 \ No newline at end of file diff --git a/configs/exps/deup/gnn/pretrain-depfaenet.yaml b/configs/exps/deup/gnn/pretrain-depfaenet.yaml new file mode 100644 index 0000000000..83029997d4 --- /dev/null +++ b/configs/exps/deup/gnn/pretrain-depfaenet.yaml @@ -0,0 +1,78 @@ +job: + mem: 32GB + cpus: 4 + gres: gpu:1 + partition: long + time: 15:00:00 + +default: + wandb_project: ocp-deup + wandb_tags: gflownet-model, depfaenet + test_ri: True + mode: train + graph_rewiring: "" + frame_averaging: 2D + fa_method: 
se3-random + cp_data_to_tmpdir: True + is_disconnected: true + model: + mp_type: updownscale_base + phys_embeds: True + tag_hidden_channels: 0 + pg_hidden_channels: 96 + energy_head: weighted-av-final-embeds + complex_mp: True + graph_norm: True + hidden_channels: 352 + num_filters: 288 + num_gaussians: 68 + num_interactions: 5 + second_layer_MLP: False + skip_co: False + cutoff: 4.0 + dropout_lin: 0.3 + optim: + batch_size: 256 + eval_batch_size: 256 + lr_initial: 0.002 + scheduler: LinearWarmupCosineAnnealingLR + eval_every: 0.4 + +runs: + + - config: depfaenet-is2re-all + note: Depfaenet pre-train + dropout + optim: + max_epochs: 12 + lr_initial: 0.0002 + + - config: depfaenet-is2re-all + note: Depfaenet pre-train + dropout + optim: + max_epochs: 10 + lr_initial: 0.0001 + + - config: depfaenet-is2re-all + note: depfaenet with top configs + dropout + model: + mp_type: updownscale_base + phys_embeds: True + tag_hidden_channels: 32 + pg_hidden_channels: 96 + energy_head: weighted-av-final-embeds + complex_mp: True + graph_norm: True + hidden_channels: 352 + num_filters: 288 + num_gaussians: 68 + num_interactions: 5 + second_layer_MLP: False + skip_co: False + cutoff: 4.0 + optim: + batch_size: 256 + eval_batch_size: 256 + lr_initial: 0.002 + scheduler: LinearWarmupCosineAnnealingLR + max_epochs: 9 + eval_every: 0.4 \ No newline at end of file diff --git a/ocpmodels/tasks/task.py b/ocpmodels/tasks/task.py index 8a9e3d8be3..c3c938eec0 100644 --- a/ocpmodels/tasks/task.py +++ b/ocpmodels/tasks/task.py @@ -27,10 +27,10 @@ def setup(self, trainer): self.trainer.load_checkpoint(self.config["checkpoint"]) print() - # save checkpoint path to runner state for slurm resubmissions - self.chkpt_path = os.path.join( - self.trainer.config["checkpoint_dir"], "checkpoint.pt" - ) + # save checkpoint path to runner state for slurm resubmissions + self.chkpt_path = os.path.join( + self.trainer.config["checkpoint_dir"], "checkpoint.pt" + ) def run(self): raise NotImplementedError diff --git a/ocpmodels/trainers/single_trainer.py b/ocpmodels/trainers/single_trainer.py index 25f82ec9a7..b9e2f921cf 100644 --- a/ocpmodels/trainers/single_trainer.py +++ b/ocpmodels/trainers/single_trainer.py @@ -227,7 +227,11 @@ def train( # Calculate start_epoch from step instead of loading the epoch number # to prevent inconsistencies due to different batch size in checkpoint. 
- if self.config["continue_from_dir"] is not None and self.config["adsorbates"] not in {None, "all"}: + if ( + "continue_from_dir" in self.config + and self.config["continue_from_dir"] is not None + and self.config["adsorbates"] not in {None, "all"} + ): self.step = 0 start_epoch = self.step // n_train max_epochs = self.config["optim"]["max_epochs"] @@ -589,11 +593,15 @@ def compute_loss(self, preds, batch_list): # Energy loss energy_target = torch.cat( [ - batch.y_relaxed.to(self.device) - if self.task_name == "is2re" - else batch.deup_loss.to(self.device) - if self.task_name == "deup_is2re" - else batch.y.to(self.device) + ( + batch.y_relaxed.to(self.device) + if self.task_name == "is2re" + else ( + batch.deup_loss.to(self.device) + if self.task_name == "deup_is2re" + else batch.y.to(self.device) + ) + ) for batch in batch_list ], dim=0, @@ -706,11 +714,15 @@ def compute_metrics( target = { "energy": torch.cat( [ - batch.y_relaxed.to(self.device) - if self.task_name == "is2re" - else batch.deup_loss.to(self.device) - if self.task_name == "deup_is2re" - else batch.y.to(self.device) + ( + batch.y_relaxed.to(self.device) + if self.task_name == "is2re" + else ( + batch.deup_loss.to(self.device) + if self.task_name == "deup_is2re" + else batch.y.to(self.device) + ) + ) for batch in batch_list ], dim=0, From 8599de7e1b56b8a0a590fd6c9212396b358d3fe1 Mon Sep 17 00:00:00 2001 From: AlexDuvalinho Date: Wed, 24 Apr 2024 07:57:56 -0400 Subject: [PATCH 14/45] adapt configs for v0 deup-faenet training on deup-dataset --- .../exps/deup/uncertainty/faenet_test.yaml | 31 ++ configs/exps/deup/uncertainty/v1.yaml | 13 +- configs/models/deup_depfaenet.yaml | 273 ++++++++++++++++++ configs/models/deup_faenet.yaml | 53 ++-- configs/models/tasks/deup_is2re.yaml | 2 +- 5 files changed, 340 insertions(+), 32 deletions(-) create mode 100644 configs/exps/deup/uncertainty/faenet_test.yaml create mode 100644 configs/models/deup_depfaenet.yaml diff --git a/configs/exps/deup/uncertainty/faenet_test.yaml b/configs/exps/deup/uncertainty/faenet_test.yaml new file mode 100644 index 0000000000..19742c31cc --- /dev/null +++ b/configs/exps/deup/uncertainty/faenet_test.yaml @@ -0,0 +1,31 @@ +job: + mem: 32GB + cpus: 4 + gres: gpu:1 + partition: long + +default: + config: deup_faenet-deup_is2re-all + wandb_project: ocp-deup + wandb_tags: faenet, MC-D, 4616500-model, 4642835-dataset + test_ri: True + mode: train + model: + dropout_lowest_layer: null + first_trainable_layer: output + dropout_lin: 0.3 + cp_data_to_tmpdir: false + inference_time_loops: 1 + # restart_from_dir: /network/scratch/a/alexandre.duval/ocp/runs/4621042/ + # checkpoint: /network/scratch/a/alexandre.duval/ocp/runs/4621042/ + dataset: # mandatory if restart_from_dir is set + default_val: deup-val_ood_cat-val_ood_ads + deup-train-val_id: + src: /network/scratch/a/alexandre.duval/ocp/runs/4642835/deup_dataset + deup-val_ood_cat-val_ood_ads: + src: /network/scratch/a/alexandre.duval/ocp/runs/4642835/deup_dataset + deup_dataset: + create: False + +runs: + - note: deup-faenet d=0.2 (not trained with d) \ No newline at end of file diff --git a/configs/exps/deup/uncertainty/v1.yaml b/configs/exps/deup/uncertainty/v1.yaml index 4f69d7828f..1f6a064c24 100644 --- a/configs/exps/deup/uncertainty/v1.yaml +++ b/configs/exps/deup/uncertainty/v1.yaml @@ -6,25 +6,24 @@ job: default: config: deup_faenet-deup_is2re-all - wandb_project: ocp-deup - wandb_tags: base-model, MC-D, 3264530 + wandb_tags: base-model, MC-D, 4616500-model, 4642835-dataset test_ri: True mode: train 
model: dropout_lowest_layer: null first_trainable_layer: output - dropout_lin: 0.7 + dropout_lin: 0.3 cp_data_to_tmpdir: false inference_time_loops: 1 - restart_from_dir: /network/scratch/s/schmidtv/ocp/runs/3264530 - checkpoint: /network/scratch/s/schmidtv/ocp/runs/3264530 + restart_from_dir: /network/scratch/a/alexandre.duval/ocp/runs/4621042/ + # checkpoint: /network/scratch/a/alexandre.duval/ocp/runs/4621042/ dataset: # mandatory if restart_from_dir is set default_val: deup-val_ood_cat-val_ood_ads deup-train-val_id: - src: /network/scratch/s/schmidtv/ocp/runs/3264530/deup_dataset + src: /network/scratch/a/alexandre.duval/ocp/runs/4642835/deup_dataset deup-val_ood_cat-val_ood_ads: - src: /network/scratch/s/schmidtv/ocp/runs/3264530/deup_dataset + src: /network/scratch/a/alexandre.duval/ocp/runs/4642835/deup_dataset deup_dataset: create: False diff --git a/configs/models/deup_depfaenet.yaml b/configs/models/deup_depfaenet.yaml new file mode 100644 index 0000000000..24ab2587c9 --- /dev/null +++ b/configs/models/deup_depfaenet.yaml @@ -0,0 +1,273 @@ +default: + model: + name: deup_depfaenet + act: swish + dropout_lin: 0.0 + dropout_edge: 0.0 + dropout_lowest_layer: output # lowest layer where `dropout_lin` is applied. Can be `inter-{i}` or `output`. Defaults to `output`. + first_trainable_layer: dropout # lowest layer to NOT freeze. All previous layers will be frozen. Can be ``, `embed`, `inter-{i}`, `output`, or `dropout`. + # if it is `` then no layer is frozen. If it is `dropout` then it will be set to the layer before `dropout_lowest_layer`. + # Defaults to ``. + hidden_channels: 384 + num_filters: 480 + num_interactions: 5 + num_gaussians: 104 + cutoff: 6.0 + use_pbc: True + regress_forces: False + tag_hidden_channels: 64 # only for OC20 + pg_hidden_channels: 64 # period & group embedding hidden channels + phys_embeds: True # physics-aware embeddings for atoms + phys_hidden_channels: 0 + energy_head: weighted-av-final-embeds # Energy head: {False, weighted-av-initial-embeds, weighted-av-final-embeds} + skip_co: False # Skip connections {False, "add", "concat"} + second_layer_MLP: False # in EmbeddingBlock + complex_mp: True # 2-layer MLP in Interaction blocks + mp_type: base # Message Passing type {'base', 'simple', 'updownscale', 'updownscale_base'} + graph_norm: True # graph normalization layer + force_decoder_type: "mlp" # force head (`"simple"`, `"mlp"`, `"res"`, `"res_updown"`) + force_decoder_model_config: + simple: + hidden_channels: 128 + norm: batch1d # batch1d, layer or null + mlp: + hidden_channels: 256 + norm: batch1d # batch1d, layer or null + res: + hidden_channels: 128 + norm: batch1d # batch1d, layer or null + res_updown: + hidden_channels: 128 + norm: batch1d # batch1d, layer or null + deup_features: [s, energy_pred_std] + optim: + batch_size: 256 + eval_batch_size: 256 + max_epochs: 12 + scheduler: LinearWarmupCosineAnnealingLR + optimizer: AdamW + num_workers: 4 + warmup_steps: 6000 + warmup_factor: 0.2 + lr_initial: 0.002 + lr_gamma: 0.1 + energy_grad_coefficient: 10 + force_coefficient: 30 + energy_coefficient: 1 + lr_milestones: + - 18000 + - 27000 + - 37000 + epoch_fine_tune: 4 + + frame_averaging: "" # 2D, 3D, da, False + fa_method: "" # can be {None, full, random, det, e3, e3-random, e3-det} + +# ------------------- +# ----- IS2RE ----- +# ------------------- + +deup_is2re: # was: is2re + 10k: + optim: + lr_initial: 0.005 + lr_milestones: # epochs at which lr_initial <- lr_initial * lr_gamma + - 1562 + - 2343 + - 3125 + warmup_steps: 468 + max_epochs: 20 + + 
100k: + model: + hidden_channels: 256 + optim: + lr_initial: 0.005 + lr_milestones: # epochs at which lr_initial <- lr_initial * lr_gamma + - 1562 + - 2343 + - 3125 + warmup_steps: 468 + max_epochs: 20 + + all: + model: + hidden_channels: 384 + num_interactions: 4 + optim: + batch_size: 256 + eval_batch_size: 256 + lr_initial: 0.001 + lr_gamma: 0.1 + lr_milestones: # steps at which lr_initial <- lr_initial * lr_gamma + - 18000 + - 27000 + - 37000 + warmup_steps: 6000 + max_epochs: 20 + +# ------------------ +# ----- S2EF ----- +# ------------------ + +# For 2 GPUs + +s2ef: + default: + model: + num_interactions: 4 + hidden_channels: 750 + num_gaussians: 200 + num_filters: 256 + regress_forces: "direct" + optim: + batch_size: 96 + eval_batch_size: 96 + warmup_factor: 0.2 + lr_gamma: 0.1 + lr_initial: 0.0001 + max_epochs: 15 + warmup_steps: 30000 + lr_milestones: + - 55000 + - 75000 + - 10000 + + 200k: {} + + # 1 gpus + 2M: + model: + num_interactions: 5 + hidden_channels: 1024 + num_gaussians: 200 + num_filters: 256 + optim: + batch_size: 192 + eval_batch_size: 192 + + 20M: {} + + all: {} + +qm9: + default: + model: + act: swish + att_heads: 1 + complex_mp: true + cutoff: 6.0 + energy_head: '' + graph_norm: true + graph_rewiring: null + hidden_channels: 400 + max_num_neighbors: 30 + mp_type: updownscale_base + num_filters: 480 + num_gaussians: 100 + num_interactions: 5 + otf_graph: false + pg_hidden_channels: 32 + phys_embeds: false + phys_hidden_channels: 0 + regress_forces: '' + second_layer_MLP: true + skip_co: true + tag_hidden_channels: 0 + use_pbc: false + + optim: + batch_size: 64 + es_min_abs_change: 1.0e-06 + es_patience: 20 + es_warmup_epochs: 600 + eval_batch_size: 64 + factor: 0.9 + loss_energy: mse + lr_gamma: 0.1 + lr_initial: 0.0003 + max_epochs: 1500 + min_lr: 1.0e-06 + mode: min + optimizer: AdamW + patience: 15 + scheduler: ReduceLROnPlateau + threshold: 0.0001 + threshold_mode: abs + verbose: true + warmup_factor: 0.2 + warmup_steps: 3000 + + 10k: {} + all: {} + +qm7x: + default: + model: # SOTA settings + act: swish + att_heads: 1 + complex_mp: true + cutoff: 5.0 + energy_head: false + force_decoder_model_config: + mlp: + hidden_channels: 256 + norm: batch1d + res: + hidden_channels: 128 + norm: batch1d + res_updown: + hidden_channels: 128 + norm: layer + simple: + hidden_channels: 128 + norm: batch1d + force_decoder_type: res_updown + graph_norm: false + hidden_channels: 500 + max_num_neighbors: 40 + mp_type: updownscale_base + num_filters: 400 + num_gaussians: 50 + num_interactions: 5 + otf_graph: false + pg_hidden_channels: 32 + phys_embeds: true + phys_hidden_channels: 0 + regress_forces: direct_with_gradient_target + second_layer_MLP: true + skip_co: false + tag_hidden_channels: 0 + use_pbc: false + + optim: + batch_size: 100 + energy_grad_coefficient: 5 + eval_batch_size: 100 + eval_every: 0.34 + factor: 0.75 + force_coefficient: 75 + loss_energy: mae + loss_force: mse + lr_gamma: 0.1 + lr_initial: 0.000193 + max_steps: 4000000 + min_lr: 1.0e-06 + mode: min + optimizer: AdamW + scheduler: ReduceLROnPlateau + threshold: 0.001 + threshold_mode: abs + verbose: true + warmup_factor: 0.2 + warmup_steps: 3000 + + all: {} + 1k: {} + +qm9: + default: + model: + use_pbc: False + all: {} + 10k: {} diff --git a/configs/models/deup_faenet.yaml b/configs/models/deup_faenet.yaml index bdc723bb55..f6e52681f1 100644 --- a/configs/models/deup_faenet.yaml +++ b/configs/models/deup_faenet.yaml @@ -8,27 +8,24 @@ default: first_trainable_layer: dropout # lowest layer to NOT freeze. 
All previous layers will be frozen. Can be ``, `embed`, `inter-{i}`, `output`, or `dropout`. # if it is `` then no layer is frozen. If it is `dropout` then it will be set to the layer before `dropout_lowest_layer`. # Defaults to ``. - hidden_channels: 128 - num_filters: 100 - num_interactions: 3 - num_gaussians: 100 + hidden_channels: 384 + num_filters: 480 + num_interactions: 5 + num_gaussians: 104 cutoff: 6.0 use_pbc: True regress_forces: False - # drlab attributes: - tag_hidden_channels: 0 # 32 - pg_hidden_channels: 0 # 32 -> period & group embedding hidden channels - phys_embeds: False # True + tag_hidden_channels: 64 # only for OC20 + pg_hidden_channels: 64 # period & group embedding hidden channels + phys_embeds: True # physics-aware embeddings for atoms phys_hidden_channels: 0 - energy_head: False # can be {False, weighted-av-initial-embeds, weighted-av-final-embeds} - # faenet new features - skip_co: False # output skip connections {False, "add", "concat"} + energy_head: weighted-av-final-embeds # Energy head: {False, weighted-av-initial-embeds, weighted-av-final-embeds} + skip_co: False # Skip connections {False, "add", "concat"} second_layer_MLP: False # in EmbeddingBlock - complex_mp: False - mp_type: base # {'base', 'simple', 'updownscale', 'att', 'base_with_att', 'local_env'} - graph_norm: False # bool - att_heads: 1 # int - force_decoder_type: "mlp" # can be {"" or "simple"} | only used if regress_forces is True + complex_mp: True # 2-layer MLP in Interaction blocks + mp_type: base # Message Passing type {'base', 'simple', 'updownscale', 'updownscale_base'} + graph_norm: True # graph normalization layer + force_decoder_type: "mlp" # force head (`"simple"`, `"mlp"`, `"res"`, `"res_updown"`) force_decoder_model_config: simple: hidden_channels: 128 @@ -44,19 +41,27 @@ default: norm: batch1d # batch1d, layer or null deup_features: [s, energy_pred_std] optim: - batch_size: 64 - eval_batch_size: 64 + batch_size: 256 + eval_batch_size: 256 + max_epochs: 12 + scheduler: LinearWarmupCosineAnnealingLR + optimizer: AdamW num_workers: 4 - lr_gamma: 0.1 - lr_initial: 0.001 + warmup_steps: 6000 warmup_factor: 0.2 - max_epochs: 20 - energy_grad_coefficient: 5 + lr_initial: 0.002 + lr_gamma: 0.1 + energy_grad_coefficient: 10 force_coefficient: 30 energy_coefficient: 1 + lr_milestones: + - 18000 + - 27000 + - 37000 + epoch_fine_tune: 4 - frame_averaging: False # 2D, 3D, da, False - fa_method: False # can be {None, full, random, det, e3, e3-random, e3-det} + frame_averaging: "" # 2D, 3D, da, False + fa_method: "" # can be {None, full, random, det, e3, e3-random, e3-det} # ------------------- # ----- IS2RE ----- diff --git a/configs/models/tasks/deup_is2re.yaml b/configs/models/tasks/deup_is2re.yaml index 65aab2e311..fa85d99ab9 100644 --- a/configs/models/tasks/deup_is2re.yaml +++ b/configs/models/tasks/deup_is2re.yaml @@ -41,7 +41,7 @@ default: n_samples: 7 ensemble_checkpoints: /network/scratch/a/alexandre.duval/ocp/runs/2935198 - ensemble_dropout: 0.7 + ensemble_dropout: 0.3 10k: From 58b992727c63e6156698c6dabdfb5c3d250f3d4a Mon Sep 17 00:00:00 2001 From: Christina Date: Thu, 25 Apr 2024 07:53:11 -0400 Subject: [PATCH 15/45] fix module load --- ocpmodels/common/utils.py | 18 +++++++++++------- ocpmodels/models/__init__.py | 6 +++++- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/ocpmodels/common/utils.py b/ocpmodels/common/utils.py index af7cddc22d..99185a39af 100644 --- a/ocpmodels/common/utils.py +++ b/ocpmodels/common/utils.py @@ -755,7 +755,7 @@ def 
add_edge_distance_to_graph( # Copied from https://github.com/facebookresearch/mmf/blob/master/mmf/utils/env.py#L89. -def setup_imports(): +def setup_imports(skip_modules=[]): from ocpmodels.common.registry import registry try: @@ -803,10 +803,14 @@ def setup_imports(): splits = f.split(os.sep) file_name = splits[-1] module_name = file_name[: file_name.find(".py")] - importlib.import_module("ocpmodels.%s.%s" % (key[1:], module_name)) + if module_name not in skip_modules: + importlib.import_module("ocpmodels.%s.%s" % (key[1:], module_name)) # manual model imports - importlib.import_module("ocpmodels.models.gemnet_oc.gemnet_oc") + try: + importlib.import_module("ocpmodels.models.gemnet_oc.gemnet_oc") + except: + print("unable to load gemnet_oc") experimental_folder = os.path.join(root_folder, "../experimental/") if os.path.exists(experimental_folder): @@ -1797,7 +1801,7 @@ def make_script_trainer(str_args=[], overrides={}, silent=False, mode="train"): return trainer -def make_config_from_dir(path, mode, overrides={}, silent=None): +def make_config_from_dir(path, mode, overrides={}, silent=None, setup_imports=[]): """ Make a config from a directory. This is useful when restarting or continuing from a previous run. @@ -1834,11 +1838,11 @@ def make_config_from_dir(path, mode, overrides={}, silent=None): config = build_config(default_args, silent=silent) config = merge_dicts(config, overrides) - setup_imports() + setup_imports(setup_imports=setup_imports) return config -def make_trainer_from_dir(path, mode, overrides={}, silent=None): +def make_trainer_from_dir(path, mode, overrides={}, silent=None, skip_imports=[]): """ Make a trainer from a directory. @@ -1854,7 +1858,7 @@ def make_trainer_from_dir(path, mode, overrides={}, silent=None): Returns: Trainer: The loaded trainer. """ - config = make_config_from_dir(path, mode, overrides, silent) + config = make_config_from_dir(path, mode, overrides, silent, skip_imports) return registry.get_trainer_class(config["trainer"])(**config) diff --git a/ocpmodels/models/__init__.py b/ocpmodels/models/__init__.py index 9241e161f7..8a56eaea40 100644 --- a/ocpmodels/models/__init__.py +++ b/ocpmodels/models/__init__.py @@ -8,7 +8,11 @@ from .dimenet import DimeNet # noqa: F401 from .faenet import FAENet # noqa: F401 from .depfaenet import DepFAENet # noqa: F401 -from .gemnet.gemnet import GemNetT # noqa: F401 + +try: + from .gemnet.gemnet import GemNetT # noqa: F401 +except: + print("unable to load gemnet") from .dimenet_plus_plus import DimeNetPlusPlus # noqa: F401 from .forcenet import ForceNet # noqa: F401 from .schnet import SchNet # noqa: F401 From 5a5524c59f467d406ab5fffc2dc391c685b18441 Mon Sep 17 00:00:00 2001 From: Christina Date: Thu, 25 Apr 2024 08:14:50 -0400 Subject: [PATCH 16/45] return hidden state in wrapper --- ocpmodels/common/gfn.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/ocpmodels/common/gfn.py b/ocpmodels/common/gfn.py index 95257359f3..0a1f45521c 100644 --- a/ocpmodels/common/gfn.py +++ b/ocpmodels/common/gfn.py @@ -107,6 +107,7 @@ def forward( self, batch: Union[Batch, Data, List[Data], List[Batch]], preprocess: bool = True, + retrieve_hidden: bool = False, ): """Perform a forward pass of the model when frame averaging is applied. @@ -162,6 +163,8 @@ def forward( if preds["energy"].shape[-1] == 1: preds["energy"] = preds["energy"].view(-1) + if retrieve_hidden: + return preds return preds["energy"] # denormalize? 
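
With `retrieve_hidden=True`, the wrapper now returns the full predictions
dict instead of only the energy tensor. A minimal usage sketch (assuming
`wrapper` is a `FAENetWrapper` and `batch` a torch_geometric `Batch`, as set
up in `scripts/active_learning.py` later in this series):

    preds = wrapper(batch, retrieve_hidden=True)
    energy, hidden = preds["energy"], preds["hidden_state"]
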
     def freeze(self):

From 6594960177e4f3b14853ac62ffe146ef6092d5af Mon Sep 17 00:00:00 2001
From: vict0rsch
Date: Thu, 25 Apr 2024 09:49:10 -0400
Subject: [PATCH 17/45] `scatter` `q` in `energy_forward`

---
 ocpmodels/models/faenet.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/ocpmodels/models/faenet.py b/ocpmodels/models/faenet.py
index dc7dd1efd6..7d5c6044da 100644
--- a/ocpmodels/models/faenet.py
+++ b/ocpmodels/models/faenet.py
@@ -7,17 +7,17 @@
 import torch.nn.functional as F
 from torch import nn
 from torch.nn import Embedding, Linear
-from torch_geometric.utils import dropout_edge
 from torch_geometric.nn import MessagePassing, radius_graph
 from torch_geometric.nn.norm import GraphNorm
+from torch_geometric.utils import dropout_edge
 from torch_scatter import scatter

 from ocpmodels.common.registry import registry
+from ocpmodels.common.utils import conditional_grad, get_pbc_distances
 from ocpmodels.models.base_model import BaseModel
 from ocpmodels.models.force_decoder import ForceDecoder
 from ocpmodels.models.utils.activations import swish
 from ocpmodels.modules.phys_embeddings import PhysEmbedding
-from ocpmodels.common.utils import get_pbc_distances, conditional_grad


 class GaussianSmearing(nn.Module):
@@ -751,6 +751,9 @@ def energy_forward(self, data, q=None):
             q = h.clone().detach()

         else:
+            # WARNING: this branch receives a precomputed q, which is NOT the
+            # hidden state h if q was stored as a scattered version of h.
+            # This works for GPs, NOT for MC-Dropout.
             h = q
             alpha = None
@@ -763,6 +766,9 @@
         elif self.skip_co == "add":
             energy = sum(energy_skip_co)

+        if q and len(q) > len(energy):
+            q = scatter(q, batch, dim=0, reduce="mean")  # N_graphs x hidden_channels
+
         preds = {
             "energy": energy,
             "hidden_state": h,

From 5ce2f3f7c074aa063238d153817afca1418be35f Mon Sep 17 00:00:00 2001
From: AlexDuvalinho
Date: Thu, 25 Apr 2024 14:03:17 -0400
Subject: [PATCH 18/45] fix configs for depfaenet/faenet fine-tuning

---
 ...-training.yaml => depfaenet-finetune.yaml} | 46 +++++++++++--------
 configs/exps/deup/gnn/faenet-finetune.yaml    | 46 +++++++++++++++++++
 2 files changed, 74 insertions(+), 18 deletions(-)
 rename configs/exps/deup/gnn/{depfaenet-training.yaml => depfaenet-finetune.yaml} (78%)
 create mode 100644 configs/exps/deup/gnn/faenet-finetune.yaml

diff --git a/configs/exps/deup/gnn/depfaenet-training.yaml b/configs/exps/deup/gnn/depfaenet-finetune.yaml
similarity index 78%
rename from configs/exps/deup/gnn/depfaenet-training.yaml
rename to configs/exps/deup/gnn/depfaenet-finetune.yaml
index d81ac5d384..1da2e29bc6 100644
--- a/configs/exps/deup/gnn/depfaenet-training.yaml
+++ b/configs/exps/deup/gnn/depfaenet-finetune.yaml
@@ -13,6 +13,33 @@ default:
   graph_rewiring: ""
   frame_averaging: 2D
   fa_method: se3-random
+  is_disconnected: true
+
+runs:
+
+  - config: depfaenet-is2re-all
+    note: Depfaenet per-ads-dataset
+    continue_from_dir: /network/scratch/a/alexandre.duval/ocp/runs/4647488 #4647466 #4023244
+    adsorbates: "*O, *OH, *OH2, *H"
+    optim:
+      max_epochs: 10
+      lr_initial: 0.0002
+
+  - config: depfaenet-is2re-all
+    note: Depfaenet per-ads-dataset
+    continue_from_dir: /network/scratch/a/alexandre.duval/ocp/runs/4647488 #4647466 # 4023244
+    adsorbates: "*O, *OH, *OH2, *H"
+    optim:
+      max_epochs: 12
+      lr_initial: 0.0001
+
+- config: depfaenet-is2re-all
+  note: Depfaenet per-ads-dataset
+  continue_from_dir: /network/scratch/a/alexandre.duval/ocp/runs/4647466 # 4023244
+  adsorbates: "*O, *OH, *OH2, *H"
+  graph_rewiring: ""
+  frame_averaging: 2D
+  fa_method: se3-random
cp_data_to_tmpdir: True is_disconnected: true model: @@ -37,21 +64,4 @@ default: lr_initial: 0.002 scheduler: LinearWarmupCosineAnnealingLR eval_every: 0.4 - -runs: - - - config: depfaenet-is2re-all - note: Depfaenet per-ads-dataset - continue_from_dir: /network/scratch/a/alexandre.duval/ocp/runs/4023244 - adsorbates: "*O, *OH, *OH2, *H" - optim: - max_epochs: 10 - lr_initial: 0.0002 - - - config: depfaenet-is2re-all - note: Depfaenet per-ads-dataset - continue_from_dir: /network/scratch/a/alexandre.duval/ocp/runs/4023244 - adsorbates: "*O, *OH, *OH2, *H" - optim: - max_epochs: 12 - lr_initial: 0.0001 + max_epochs: 12 \ No newline at end of file diff --git a/configs/exps/deup/gnn/faenet-finetune.yaml b/configs/exps/deup/gnn/faenet-finetune.yaml new file mode 100644 index 0000000000..fae81269a0 --- /dev/null +++ b/configs/exps/deup/gnn/faenet-finetune.yaml @@ -0,0 +1,46 @@ +job: + mem: 32GB + cpus: 4 + gres: gpu:1 + partition: long + time: 18:00:00 + +default: + test_ri: True + mode: train + wandb_tags: faenet, no-concat, with-tag0, dropout, fine-tuned + wandb_project: ocp-deup + optim: + batch_size: 256 + eval_batch_size: 256 + cp_data_to_tmpdir: True + +runs: + - config: faenet-is2re-all + note: "fine-tuned faenet" + continue_from_dir: /network/scratch/a/alexandre.duval/ocp/runs/4647489 + adsorbates: "*O, *OH, *OH2, *H" + frame_averaging: 2D + fa_method: se3-random + model: + mp_type: updownscale_base + phys_embeds: True + tag_hidden_channels: 32 + pg_hidden_channels: 96 + energy_head: weighted-av-final-embeds + complex_mp: True + graph_norm: True + hidden_channels: 384 + num_filters: 480 + num_gaussians: 104 + num_interactions: 5 + second_layer_MLP: False + skip_co: False + cutoff: 6.0 + dropout_lin: 0.3 + dropout_lowest_layer: output + optim: + lr_initial: 0.002 + scheduler: LinearWarmupCosineAnnealingLR + max_epochs: 14 + eval_every: 0.25 \ No newline at end of file From 606fcd07ae72aa3b80d48abce3fc62ed13593dbc Mon Sep 17 00:00:00 2001 From: Christina Date: Fri, 26 Apr 2024 02:33:56 -0400 Subject: [PATCH 19/45] quickfixes --- ocpmodels/common/utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ocpmodels/common/utils.py b/ocpmodels/common/utils.py index 99185a39af..b7a39d391b 100644 --- a/ocpmodels/common/utils.py +++ b/ocpmodels/common/utils.py @@ -755,7 +755,7 @@ def add_edge_distance_to_graph( # Copied from https://github.com/facebookresearch/mmf/blob/master/mmf/utils/env.py#L89. -def setup_imports(skip_modules=[]): +def setup_imports(skip_imports=[]): from ocpmodels.common.registry import registry try: @@ -803,7 +803,7 @@ def setup_imports(skip_modules=[]): splits = f.split(os.sep) file_name = splits[-1] module_name = file_name[: file_name.find(".py")] - if module_name not in skip_modules: + if module_name not in skip_imports: importlib.import_module("ocpmodels.%s.%s" % (key[1:], module_name)) # manual model imports @@ -1801,7 +1801,7 @@ def make_script_trainer(str_args=[], overrides={}, silent=False, mode="train"): return trainer -def make_config_from_dir(path, mode, overrides={}, silent=None, setup_imports=[]): +def make_config_from_dir(path, mode, overrides={}, silent=None, skip_imports=[]): """ Make a config from a directory. This is useful when restarting or continuing from a previous run. 
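
With the keyword renamed consistently, a caller can rebuild a trainer while
skipping modules that fail to import. A hedged sketch (the run directory and
skipped module name are illustrative, not taken from this series):

    # e.g. skip a model module that cannot be imported on this cluster
    trainer = make_trainer_from_dir(
        "/path/to/run_dir", mode="continue", skip_imports=["gemnet"]
    )
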
@@ -1838,7 +1838,7 @@ def make_config_from_dir(path, mode, overrides={}, silent=None, setup_imports=[] config = build_config(default_args, silent=silent) config = merge_dicts(config, overrides) - setup_imports(setup_imports=setup_imports) + setup_imports(skip_imports=skip_imports) return config From fcf265055d8f9c7f6ce146670eed92655af94f8c Mon Sep 17 00:00:00 2001 From: AlexDuvalinho Date: Fri, 26 Apr 2024 03:30:53 -0400 Subject: [PATCH 20/45] update configs deup-depfaenet --- configs/exps/deup/gnn/faenet-finetune.yaml | 5 +- configs/exps/deup/gnn/faenet-training.yaml | 5 +- .../exps/deup/uncertainty/deup_depfaenet.yaml | 61 +++++++++++++++++++ ocpmodels/datasets/deup_dataset_creator.py | 4 +- scripts/gnn_dev.py | 9 ++- 5 files changed, 73 insertions(+), 11 deletions(-) create mode 100644 configs/exps/deup/uncertainty/deup_depfaenet.yaml diff --git a/configs/exps/deup/gnn/faenet-finetune.yaml b/configs/exps/deup/gnn/faenet-finetune.yaml index fae81269a0..de51808345 100644 --- a/configs/exps/deup/gnn/faenet-finetune.yaml +++ b/configs/exps/deup/gnn/faenet-finetune.yaml @@ -10,9 +10,10 @@ default: mode: train wandb_tags: faenet, no-concat, with-tag0, dropout, fine-tuned wandb_project: ocp-deup + graph_rewiring: "" optim: - batch_size: 256 - eval_batch_size: 256 + batch_size: 232 + eval_batch_size: 232 cp_data_to_tmpdir: True runs: diff --git a/configs/exps/deup/gnn/faenet-training.yaml b/configs/exps/deup/gnn/faenet-training.yaml index 0d6aa34d51..c3775e66e5 100644 --- a/configs/exps/deup/gnn/faenet-training.yaml +++ b/configs/exps/deup/gnn/faenet-training.yaml @@ -11,13 +11,14 @@ default: wandb_tags: faenet, no-concat, with-tag0, dropout wandb_project: ocp-deup optim: - batch_size: 256 - eval_batch_size: 256 + batch_size: 200 + eval_batch_size: 200 cp_data_to_tmpdir: True runs: - config: faenet-is2re-all note: "top run no concat" + graph_rewiring: "" frame_averaging: 2D fa_method: se3-random model: diff --git a/configs/exps/deup/uncertainty/deup_depfaenet.yaml b/configs/exps/deup/uncertainty/deup_depfaenet.yaml new file mode 100644 index 0000000000..7b3ccd8a1e --- /dev/null +++ b/configs/exps/deup/uncertainty/deup_depfaenet.yaml @@ -0,0 +1,61 @@ +job: + mem: 32GB + cpus: 4 + gres: gpu:1 + partition: long + +default: + config: deup_faenet-deup_is2re-all + wandb_project: ocp-deup + wandb_tags: deup-depfaenet, 4648581-model, 4657270-dataset + test_ri: True + mode: train + model: + dropout_lowest_layer: output + first_trainable_layer: output + dropout_lin: 0.3 + cp_data_to_tmpdir: false + inference_time_loops: 1 + restart_from_dir: /network/scratch/a/alexandre.duval/ocp/runs/4648581/ + checkpoint: /network/scratch/a/alexandre.duval/ocp/runs/4648581/ + dataset: # mandatory if restart_from_dir is set + default_val: deup-val_ood_cat-val_ood_ads + deup-train-val_id: + src: /network/scratch/a/alexandre.duval/ocp/runs/4657270/deup_dataset + deup-val_ood_cat-val_ood_ads: + src: /network/scratch/a/alexandre.duval/ocp/runs/4657270/deup_dataset + deup_dataset: + create: False + +runs: + - note: deup-depfaenet (with dropout) + graph_rewiring: "" + frame_averaging: 2D + fa_method: se3-random + cp_data_to_tmpdir: True + is_disconnected: true + model: + mp_type: updownscale_base + phys_embeds: True + tag_hidden_channels: 0 + pg_hidden_channels: 96 + energy_head: weighted-av-final-embeds + complex_mp: True + graph_norm: True + hidden_channels: 352 + num_filters: 288 + num_gaussians: 68 + num_interactions: 5 + second_layer_MLP: False + skip_co: False + cutoff: 4.0 + dropout_lin: 0.3 + optim: + batch_size: 256 
+ eval_batch_size: 256 + lr_initial: 0.002 + scheduler: LinearWarmupCosineAnnealingLR + eval_every: 0.4 + max_epochs: 12 + + - note: deup-depfaenet (without specifying configs) \ No newline at end of file diff --git a/ocpmodels/datasets/deup_dataset_creator.py b/ocpmodels/datasets/deup_dataset_creator.py index f3a4e3adc8..fca3fea8ba 100644 --- a/ocpmodels/datasets/deup_dataset_creator.py +++ b/ocpmodels/datasets/deup_dataset_creator.py @@ -441,13 +441,13 @@ def parse_args(): parser.add_argument( "--checkpoints", nargs="+", - default="/network/scratch/a/alexandre.duval/ocp/runs/4616500/", + default="/network/scratch/a/alexandre.duval/ocp/runs/4648581/", help="Paths to the checkpoints", ) parser.add_argument( "--dropout", type=float, - default=0.2, + default=0.3, help="Dropout value", ) return parser.parse_args() diff --git a/scripts/gnn_dev.py b/scripts/gnn_dev.py index bc22055536..bc3924fbe2 100644 --- a/scripts/gnn_dev.py +++ b/scripts/gnn_dev.py @@ -16,7 +16,7 @@ if __name__ == "__main__": config = {} # Customize args - config["graph_rewiring"] = "remove-tag-0" + config["graph_rewiring"] = "" config["frame_averaging"] = "2D" config["fa_method"] = "random" # "random" config["test_ri"] = False @@ -29,10 +29,9 @@ str_args = sys.argv[1:] if all("config" not in arg for arg in str_args): str_args.append("--is_debug") - # str_args.append("--config=faenet-is2re-all") - str_args.append("--config=faenet-is2re-10k") - str_args.append("--adsorbates={'*O', '*OH', '*OH2', '*H'}") - # str_args.append("--is_disconnected=True") + str_args.append("--config=deup_depfaenet-deup_is2re-10k") + # str_args.append("--adsorbates={'*O', '*OH', '*OH2', '*H'}") + str_args.append("--is_disconnected=True") # str_args.append("--silent=0") warnings.warn( "No model / mode is given; chosen as default" + f"Using: {str_args[-1]}" From 4d73707b4399c29970ee68c6ed870d74660837e2 Mon Sep 17 00:00:00 2001 From: AlexDuvalinho Date: Fri, 26 Apr 2024 03:48:24 -0400 Subject: [PATCH 21/45] test use deup-dataset in an active learning framework --- scripts/active_learning.py | 97 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 97 insertions(+) create mode 100644 scripts/active_learning.py diff --git a/scripts/active_learning.py b/scripts/active_learning.py new file mode 100644 index 0000000000..9d8d33bb61 --- /dev/null +++ b/scripts/active_learning.py @@ -0,0 +1,97 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. 
+""" +import sys +import warnings +from pathlib import Path + +sys.path.append(str(Path(__file__).resolve().parent.parent)) + +from ocpmodels.common.utils import make_script_trainer, make_trainer_from_dir +from ocpmodels.common.gfn import FAENetWrapper +from ocpmodels.trainers import SingleTrainer +from ocpmodels.datasets.lmdb_dataset import DeupDataset +from ocpmodels.datasets.data_transforms import get_transforms + +if __name__ == "__main__": + + deup_dataset_chkpt = "/network/scratch/a/alexandre.duval/ocp/runs/4657270/deup_dataset" + model_chkpt = "/network/scratch/a/alexandre.duval/ocp/runs/4648581/checkpoints/best_checkpoint.pt" + + data_config = { + "default_val": "deup-val_ood_cat-val_ood_ads", + "deup-train-val_id": { + "src": deup_dataset_chkpt + }, + "deup-val_ood_cat-val_ood_ads": { + "src": deup_dataset_chkpt + }, + "train": { + "src": "/network/scratch/s/schmidtv/ocp/datasets/ocp/is2re/all/train/", + "normalize_labels": True, + }, + "val_id": { + "src": "/network/scratch/s/schmidtv/ocp/datasets/ocp/is2re/all/val_id/" + }, + "val_ood_cat": { + "src": "/network/scratch/s/schmidtv/ocp/datasets/ocp/is2re/all/val_ood_cat/" + }, + "val_ood_ads": { + "src": "/network/scratch/s/schmidtv/ocp/datasets/ocp/is2re/all/val_ood_ads/" + }, + "val_ood_both": { + "src": "/network/scratch/s/schmidtv/ocp/datasets/ocp/is2re/all/val_ood_both/" + }, + } + + trainer = make_trainer_from_dir( + model_chkpt, + mode="continue", + overrides={ + "is_debug": True, + "silent": True, + "cp_data_to_tmpdir": False, + "config": "depfaenet-deup_is2re-all", + "deup_dataset.create": False, + "dataset": data_config, + }, + silent=True, + ) + + wrapper = FAENetWrapper( + faenet=trainer.model, + transform=get_transforms(trainer.config), + frame_averaging=trainer.config.get("frame_averaging", ""), + trainer_config=trainer.config, + ) + + wrapper.freeze() + loaders = trainer.loaders + + data_gen = iter(loaders["deup-train-val_id"]) + batch = next(data_gen) + preds = wrapper(batch) + + # trainer.config["dataset"].update({ + # "deup-train-val_id": { + # "src": "/network/scratch/s/schmidtv/ocp/runs/3301084/deup_dataset" + # }, + # "deup-val_ood_cat-val_ood_ads": { + # "src": "/network/scratch/s/schmidtv/ocp/runs/3301084/deup_dataset" + # }, + # "default_val": "deup-val_ood_cat-val_ood_ads" + # }) + + # deup_dataset_path = "/network/scratch/a/alexandre.duval/ocp/runs/4642835/deup_dataset" + # deup_dataset = DeupDataset( + # { + # **trainer.config["dataset"], + # }, + # "deup-train-val_id", + # transform=get_transforms(trainer.config), + # ) + + # deup_sample = deup_dataset[0] \ No newline at end of file From 175567efaebe65587de13c74b1b8b01415bb509f Mon Sep 17 00:00:00 2001 From: AlexDuvalinho Date: Fri, 17 May 2024 10:58:25 -0400 Subject: [PATCH 22/45] deupdepfaenet configs --- configs/exps/deup/uncertainty/deup_depfaenet.yaml | 6 +++--- configs/models/deup_depfaenet.yaml | 2 +- configs/models/deup_faenet.yaml | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/configs/exps/deup/uncertainty/deup_depfaenet.yaml b/configs/exps/deup/uncertainty/deup_depfaenet.yaml index 7b3ccd8a1e..07f4f142f8 100644 --- a/configs/exps/deup/uncertainty/deup_depfaenet.yaml +++ b/configs/exps/deup/uncertainty/deup_depfaenet.yaml @@ -5,7 +5,7 @@ job: partition: long default: - config: deup_faenet-deup_is2re-all + config: deup_depfaenet-deup_is2re-all wandb_project: ocp-deup wandb_tags: deup-depfaenet, 4648581-model, 4657270-dataset test_ri: True @@ -16,8 +16,8 @@ default: dropout_lin: 0.3 cp_data_to_tmpdir: false 
inference_time_loops: 1 - restart_from_dir: /network/scratch/a/alexandre.duval/ocp/runs/4648581/ - checkpoint: /network/scratch/a/alexandre.duval/ocp/runs/4648581/ + # restart_from_dir: /network/scratch/a/alexandre.duval/ocp/runs/4648581/ + # checkpoint: /network/scratch/a/alexandre.duval/ocp/runs/4648581/ dataset: # mandatory if restart_from_dir is set default_val: deup-val_ood_cat-val_ood_ads deup-train-val_id: diff --git a/configs/models/deup_depfaenet.yaml b/configs/models/deup_depfaenet.yaml index 24ab2587c9..be23501f7f 100644 --- a/configs/models/deup_depfaenet.yaml +++ b/configs/models/deup_depfaenet.yaml @@ -39,7 +39,7 @@ default: res_updown: hidden_channels: 128 norm: batch1d # batch1d, layer or null - deup_features: [s, energy_pred_std] + deup_features: [s, energy_pred_std] # add q for density optim: batch_size: 256 eval_batch_size: 256 diff --git a/configs/models/deup_faenet.yaml b/configs/models/deup_faenet.yaml index f6e52681f1..c11f6e450d 100644 --- a/configs/models/deup_faenet.yaml +++ b/configs/models/deup_faenet.yaml @@ -39,7 +39,7 @@ default: res_updown: hidden_channels: 128 norm: batch1d # batch1d, layer or null - deup_features: [s, energy_pred_std] + deup_features: [s, energy_pred_std] # add q for density if it exists optim: batch_size: 256 eval_batch_size: 256 From 03f30388994907671668c6b77396f23bf1b84e3a Mon Sep 17 00:00:00 2001 From: AlexDuvalinho Date: Fri, 17 May 2024 11:01:41 -0400 Subject: [PATCH 23/45] fix issues with q + enforce graph-level deup-dataset --- ocpmodels/datasets/deup_dataset_creator.py | 2 ++ ocpmodels/models/depfaenet.py | 2 +- ocpmodels/models/deup_depfaenet.py | 20 +++++++++++--------- ocpmodels/models/deup_faenet.py | 15 ++++++++------- ocpmodels/models/faenet.py | 6 ++++-- 5 files changed, 26 insertions(+), 19 deletions(-) diff --git a/ocpmodels/datasets/deup_dataset_creator.py b/ocpmodels/datasets/deup_dataset_creator.py index fca3fea8ba..8b6e193c5b 100644 --- a/ocpmodels/datasets/deup_dataset_creator.py +++ b/ocpmodels/datasets/deup_dataset_creator.py @@ -329,6 +329,7 @@ def create_deup_dataset( pred_mean, batch.y_relaxed.to(pred_mean.device) ) # Store deup samples + assert len(preds["q"]) == len(batch) deup_samples += [ { "energy_target": batch.y_relaxed.clone(), @@ -481,6 +482,7 @@ def parse_args(): # base_config = make_config_from_conf_str("faenet-is2re-all") # base_datasets_config = base_config["dataset"] + # Load deup dataset deup_dataset = DeupDataset( { **base_datasets_config, diff --git a/ocpmodels/models/depfaenet.py b/ocpmodels/models/depfaenet.py index 87f76e08c1..4e83dbc0ec 100644 --- a/ocpmodels/models/depfaenet.py +++ b/ocpmodels/models/depfaenet.py @@ -92,6 +92,6 @@ def __init__(self, **kwargs): def energy_forward(self, data, q=None): # We need to save the tags so this step is necessary. 
self.output_block.tags_saver(data.tags) - pred = super().energy_forward(data) + pred = super().energy_forward(data, q) return pred diff --git a/ocpmodels/models/deup_depfaenet.py b/ocpmodels/models/deup_depfaenet.py index 8457acf45d..619ff6a68a 100644 --- a/ocpmodels/models/deup_depfaenet.py +++ b/ocpmodels/models/deup_depfaenet.py @@ -30,6 +30,7 @@ def __init__( ) def forward(self, h, edge_index, edge_weight, batch, alpha, data=None): + # If sample density is used as feature, we need to add the extra dimension if self._set_q_dim: assert data is not None assert "deup_q" in data.to_dict().keys() @@ -58,13 +59,14 @@ def forward(self, h, edge_index, edge_weight, batch, alpha, data=None): }: h = h * alpha - # Global pooling -- get final graph rep - out = scatter( - h, - batch, - dim=0, - reduce="mean" if self.deup_extra_dim > 0 else "add", - ) + # Pool into a graph rep if necessary + if len(h) > len(batch): + h = scatter( + h, + batch, + dim=0, + reduce="mean" if self.deup_extra_dim > 0 else "add", + ) # Concat graph representation with deup features (s, kde(q), std) # and apply MLPs @@ -76,7 +78,7 @@ def forward(self, h, edge_index, edge_weight, batch, alpha, data=None): + f" from the data dict ({data_keys})" ) out = torch.cat( - [out] + [h] + [data[f"deup_{k}"][:, None].float() for k in self.deup_features], dim=-1, ) @@ -87,7 +89,7 @@ def forward(self, h, edge_index, edge_weight, batch, alpha, data=None): return out @registry.register_model("deup_depfaenet") -class DeupFAENet(DepFAENet): +class DeupDepFAENet(DepFAENet): def __init__(self, *args, **kwargs): kwargs["dropout_edge"] = 0 super().__init__(*args, **kwargs) diff --git a/ocpmodels/models/deup_faenet.py b/ocpmodels/models/deup_faenet.py index 88a55964c7..726e0e3503 100644 --- a/ocpmodels/models/deup_faenet.py +++ b/ocpmodels/models/deup_faenet.py @@ -58,12 +58,13 @@ def forward(self, h, edge_index, edge_weight, batch, alpha, data=None): h = h * alpha # Global pooling -- get final graph rep - out = scatter( - h, - batch, - dim=0, - reduce="mean" if self.deup_extra_dim > 0 else "add", - ) + if len(h) > len(batch): + h = scatter( + h, + batch, + dim=0, + reduce="mean" if self.deup_extra_dim > 0 else "add", + ) # Concat graph representation with deup features (s, kde(q), std) # and apply MLPs @@ -75,7 +76,7 @@ def forward(self, h, edge_index, edge_weight, batch, alpha, data=None): + f" from the data dict ({data_keys})" ) out = torch.cat( - [out] + [h] + [data[f"deup_{k}"][:, None].float() for k in self.deup_features], dim=-1, ) diff --git a/ocpmodels/models/faenet.py b/ocpmodels/models/faenet.py index 7d5c6044da..78b9980cd7 100644 --- a/ocpmodels/models/faenet.py +++ b/ocpmodels/models/faenet.py @@ -711,7 +711,7 @@ def energy_forward(self, data, q=None): edge_attr = edge_attr[edge_mask] rel_pos = rel_pos[edge_mask] - if q is None: + if not hasattr(data, "deup_q"): # Embedding block h, e = self.embed_block(z, rel_pos, edge_attr, data.tags) @@ -754,6 +754,7 @@ def energy_forward(self, data, q=None): # WARNING # q which is NOT the hidden state h if it was stored as a scattered # version of h. This works for GPs, NOT for MC-dropout + q = data.deup_q # No need to clone # TODO: check that it's not a problem (move to deup models) h = q alpha = None @@ -766,7 +767,8 @@ def energy_forward(self, data, q=None): elif self.skip_co == "add": energy = sum(energy_skip_co) - if q and len(q) > len(energy): + # Store graph-level representation. 
# TODO: maybe want node-level rep + if q is not None and len(q) > len(energy): # N_atoms x hidden_channels q = scatter(q, batch, dim=0, reduce="mean") # N_graphs x hidden_channels preds = { From ae7b17559354e3cc5e3af29bee4b363a32c7c987 Mon Sep 17 00:00:00 2001 From: AlexDuvalinho Date: Mon, 20 May 2024 06:21:08 -0400 Subject: [PATCH 24/45] random instead of randon in yaml --- configs/models/faenet.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/configs/models/faenet.yaml b/configs/models/faenet.yaml index 3e66bba5fd..97d1136327 100644 --- a/configs/models/faenet.yaml +++ b/configs/models/faenet.yaml @@ -1,6 +1,6 @@ default: frame_averaging: "" # {"2D", "3D", "DA", ""} - fa_method: "" # {"", all, randon, det, se3-all, se3-randon, se3-det} + fa_method: "" # {"", all, random, det, se3-all, se3-random, se3-det} model: name: faenet act: swish @@ -69,7 +69,7 @@ is2re: default: graph_rewiring: remove-tag-0 frame_averaging: "2D" # {"2D", "3D", "DA", ""} - fa_method: "se3-random" # {"", all, randon, det, se3-all, se3-randon, se3-det} + fa_method: "se3-random" # {"", all, random, det, se3-all, se3-random, se3-det} # *** Important note *** # The total number of gpus used for this run was 1. # If the global batch size (num_gpus * batch_size) is modified From 7c2714cc51c35a09875c577c647371d8e9c7634c Mon Sep 17 00:00:00 2001 From: AlexDuvalinho Date: Mon, 20 May 2024 06:23:13 -0400 Subject: [PATCH 25/45] random, not stochastic --- ocpmodels/datasets/data_transforms.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ocpmodels/datasets/data_transforms.py b/ocpmodels/datasets/data_transforms.py index 17a63dfa52..db01cbb1e6 100644 --- a/ocpmodels/datasets/data_transforms.py +++ b/ocpmodels/datasets/data_transforms.py @@ -41,11 +41,11 @@ class FrameAveraging(Transform): Can be 2D FA, 3D FA, Data Augmentation or no FA, respectively denoted by (`"2D"`, `"3D"`, `"DA"`, `""`) fa_method (str): the actual frame averaging technique used. - "stochastic" refers to sampling one frame at random (at each epoch), "det" + "random" refers to sampling one frame at random (at each epoch), "det" to chosing deterministically one frame, and "all" to using all frames. The prefix "se3-" refers to the SE(3) equivariant version of the method. "" - means that no frame averaging is used. (`""`, `"stochastic"`, `"all"`, - `"det"`, `"se3-stochastic"`, `"se3-all"`, `"se3-det"`) + means that no frame averaging is used. 
(`""`, `"random"`, `"all"`, + `"det"`, `"se3-random"`, `"se3-all"`, `"se3-det"`) Returns: (data.Data): updated data object with new positions (+ unit cell) attributes From b006540fe4027a8a82854fbd3b1901411908c234 Mon Sep 17 00:00:00 2001 From: AlexDuvalinho Date: Tue, 21 May 2024 09:59:08 -0400 Subject: [PATCH 26/45] signnet analysis (workshop submission) --- scripts/signnet.py | 116 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 116 insertions(+) create mode 100644 scripts/signnet.py diff --git a/scripts/signnet.py b/scripts/signnet.py new file mode 100644 index 0000000000..9d1dfb8435 --- /dev/null +++ b/scripts/signnet.py @@ -0,0 +1,116 @@ +import sys +from pathlib import Path +import torch + +sys.path.append(str(Path(__file__).resolve().parent.parent)) + +from ocpmodels.common.utils import make_script_trainer +from ocpmodels.trainers import SingleTrainer +from torch_geometric.data import Batch + +if __name__ == "__main__": + config = {} + + # Customize args + config["graph_rewiring"] = "remove-tag-0" + config["frame_averaging"] = "3D" + config["fa_method"] = "all" + config["test_ri"] = False + # config["optim"] = {"batch_size": 1} + + str_args = sys.argv[1:] + if all("config" not in arg for arg in str_args): + str_args.append("--is_debug") + str_args.append("--config=faenet-is2re-10k") + + # Create trainer + trainer: SingleTrainer = make_script_trainer(str_args=str_args, overrides=config) + + for batch in trainer.loaders["train"]: + break + b = batch[0] + rotated_b = b.clone() + rotated_b = trainer.rotate_graph(rotated_b, rotation="z") + rotation_matrix = rotated_b["rot"] + rotated_b = rotated_b["batch_list"][0] + + # Check: X' = X R (or X = X' R^T) + assert torch.allclose(rotated_b[0].pos @ rotation_matrix.T, b[0].pos, atol=1e-04) + assert torch.allclose(b[0].pos @ rotation_matrix, rotated_b[0].pos, atol=1e-04) + # Check: X U_i = X' U_i (compare X_fa and X'fa, abs values to deal with different frames) + assert torch.allclose( + torch.abs(b[0].pos @ b[0].fa_rot[0].squeeze(0)), + torch.abs(rotated_b[0].pos @ rotated_b[0].fa_rot[0].squeeze(0)), + atol=10e-03, + ) + # Check: U_i' = R U_i + + # SignNet model + class SignNet(torch.nn.Module): + def __init__(self, in_channels=3, hidden_channels=12, out_channels=3): + super(SignNet, self).__init__() + self.mlp = torch.nn.Sequential( + torch.nn.Linear(in_channels, hidden_channels), + torch.nn.ReLU(), + torch.nn.Linear(hidden_channels, out_channels), + ) + torch.nn.init.xavier_uniform_(self.mlp[0].weight) + torch.nn.init.xavier_uniform_(self.mlp[2].weight) + self.mlp2 = torch.nn.Linear(3 * out_channels, 3 * out_channels) + + torch.nn.init.xavier_uniform_(self.mlp2.weight) + + def forward(self, x, second_mlp=False): + if second_mlp: + res = self.mlp(x) + self.mlp(-x) + res = res.view(-1) # flatten res + res = self.mlp2(res) + return res.view((3, -1)).T # reshape as eigenvector column matrix + return (self.mlp(x) + self.mlp(-x)).T + + signnet = SignNet() + second_mlp = True + + for i in range(len(b.sid)): + g = Batch.get_example(b, i) + rotated_g = Batch.get_example(rotated_b, i) + + # Test: X_fa = R X_fa' + torch.allclose(rotation_matrix @ rotated_g.fa_rot[0], g.fa_rot[0], atol=5e-01) + + # SignNet on eigenvector matrix U for g and rotated_g + # Need SignNet(U_i) = U*, for every frame U_i + # Eigenvectors are the columns of fa_rot. 
Need rows for SignNet MLPs + eigen = signnet(g.fa_rot[0].squeeze(0).T, second_mlp) + eigen_bis = signnet(g.fa_rot[1].squeeze(0).T, second_mlp) + assert torch.allclose(eigen, eigen_bis, atol=1e-04) + + # Compare with rotated graph + rot_eigen = signnet(rotated_g.fa_rot[0].squeeze(0).T, second_mlp) + # Check U*' = R U* + if torch.allclose(rot_eigen, eigen, atol=1e-4): + print("U* is invariant to rotations") + elif torch.allclose(rot_eigen, rotation_matrix @ eigen, atol=1e-4): + print("U* is equivariant to rotations") + else: + print("U* is neither invariant nor equivariant") + # Double-Check: X U* = X' U*' + new_pos = g.pos @ eigen + new_rotated_pos = rotated_g.pos @ rot_eigen + if not torch.allclose(new_pos, new_rotated_pos, atol=1e-4): + print("No equivariance: X U* != X' U*'") + + # Different eigenvalues matrix => want different U* + m = g.fa_rot[0].squeeze(0).T + torch.randn(3, 3) + e = signnet(m, second_mlp) + if torch.allclose(e, eigen, atol=1e-4): + print("Issue: distinct graph has same signnet eigenvectors") + + # Same but on real eigenvec matrix + next_g = Batch.get_example(b, i+1) + e = signnet(next_g.fa_rot[0].squeeze(0).T, second_mlp) + if torch.allclose(e, eigen, atol=1e-4): + print("Issue: distinct graph has same signnet eigenvectors") + + # Try with more complex network + # Repalce False by True in signnet above From cc503353961b78ea83294c9d72444702bec169b4 Mon Sep 17 00:00:00 2001 From: AlexDuvalinho Date: Thu, 31 Oct 2024 14:42:47 -0400 Subject: [PATCH 27/45] denormalise predictions --- ocpmodels/common/gfn.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/ocpmodels/common/gfn.py b/ocpmodels/common/gfn.py index 0a1f45521c..0e4ba73893 100644 --- a/ocpmodels/common/gfn.py +++ b/ocpmodels/common/gfn.py @@ -20,6 +20,7 @@ def __init__( transform: Callable = None, frame_averaging: str = None, trainer_config: dict = None, + normalizers: dict = None, ): """ `FAENetWrapper` is a wrapper class for the FAENet model. It is used to perform @@ -31,6 +32,7 @@ def __init__( frame_averaging (str, optional): The frame averaging method to use. trainer_config (dict, optional): The trainer config used to create the model. Defaults to None. + normalizers (dict, optional): The normalizers used to create the model. """ super().__init__() @@ -39,6 +41,7 @@ def __init__( self.frame_averaging = frame_averaging self.trainer_config = trainer_config self._is_frozen = None + self.normalizers = normalizers @property def frozen(self): @@ -165,7 +168,15 @@ def forward( if retrieve_hidden: return preds - return preds["energy"] # denormalize? 
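+        # What denorm undoes -- a sketch, assuming the standard OCP Normalizer
+        # (mean/std computed over the training targets):
+        #   denorm(x) = x * target_std + target_mean   # back to the raw eV scale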
+ breakpoint() + + # Denormalize predictions + preds["energy"] = self.normalizers["target"].denorm( + preds["energy"], + ) + # preds["energy"] = preds["energy"].to(torch.float16) + + return preds["energy"] def freeze(self): """Freeze the model parameters.""" @@ -274,6 +285,7 @@ def prepare_for_gfn(ckpt_paths: dict, release: str) -> tuple: transform=get_transforms(trainer.config), frame_averaging=trainer.config.get("frame_averaging", ""), trainer_config=trainer.config, + normalizers=trainer.normalizers, ) wrapper.freeze() loaders = trainer.loaders @@ -288,10 +300,10 @@ def prepare_for_gfn(ckpt_paths: dict, release: str) -> tuple: from ocpmodels.common.gfn import prepare_for_gfn ckpt_paths = {"mila": "/path/to/releases_dir"} - release = "v2.3_graph_phys" + release = "0.0.1" # or ckpt_paths = { - "mila": "/network/scratch/s/schmidtv/ocp/runs/3789733/checkpoints/best_checkpoint.pt" + "mila": "/network/scratch/a/alexandre.duval/ocp/catalyst-ckpts/0.0.1/best_checkpoint.pt" } release = None wrapper, loaders = prepare_for_gfn(ckpt_paths, release) From c373d7000b5c31872b55db3a93abf33790e41d2f Mon Sep 17 00:00:00 2001 From: Elena Podina Date: Thu, 28 Nov 2024 15:03:04 -0500 Subject: [PATCH 28/45] commented out surface tagging --- ocdata/surfaces.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ocdata/surfaces.py b/ocdata/surfaces.py index b96f571e81..adb8e1ca29 100644 --- a/ocdata/surfaces.py +++ b/ocdata/surfaces.py @@ -107,7 +107,7 @@ def __init__( ).reduced_formula ), "Mismatched bulk and surface" - self.tag_surface_atoms(self.bulk_object.bulk_atoms, self.surface_atoms) + #self.tag_surface_atoms(self.bulk_object.bulk_atoms, self.surface_atoms) self.constrained_surface = constrain_surface(self.surface_atoms) def tile_atoms(self, atoms): From 5996ebb61a07d19a10630d8ab2e0760114935ed9 Mon Sep 17 00:00:00 2001 From: Lena Podina Date: Sun, 1 Dec 2024 13:26:12 -0500 Subject: [PATCH 29/45] added load_datasets flag to is2re.yaml --- configs/models/tasks/is2re.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/configs/models/tasks/is2re.yaml b/configs/models/tasks/is2re.yaml index 4b1fe4304f..e237720aa6 100644 --- a/configs/models/tasks/is2re.yaml +++ b/configs/models/tasks/is2re.yaml @@ -18,6 +18,7 @@ default: mode: train adsorbates: all # {"*O", "*OH", "*OH2", "*H"} adsorbates_ref_dir: /network/scratch/s/schmidtv/ocp/datasets/ocp/per_ads + load_datasets: False dataset: default_val: val_id train: From da5c7a9f2ac84bdf1d13e4c1654cdb958604c638 Mon Sep 17 00:00:00 2001 From: Lena Podina Date: Sun, 1 Dec 2024 13:30:25 -0500 Subject: [PATCH 30/45] made load_data config-based in base_trainer.py --- ocpmodels/trainers/base_trainer.py | 62 ++++++++++++++++++++++-------- 1 file changed, 46 insertions(+), 16 deletions(-) diff --git a/ocpmodels/trainers/base_trainer.py b/ocpmodels/trainers/base_trainer.py index e871027efe..a6710a1d12 100644 --- a/ocpmodels/trainers/base_trainer.py +++ b/ocpmodels/trainers/base_trainer.py @@ -194,11 +194,54 @@ def __init__(self, load=True, **kwargs): task=self.task_name, model_regresses_forces=self.config["model"].get("regress_forces", ""), ) + + def init_normalizer(self): + self.normalizers = {} + if self.normalizer.get("normalize_labels", False): + if "target_mean" in self.normalizer: + self.normalizers["target"] = Normalizer( + mean=self.normalizer["target_mean"], + std=self.normalizer["target_std"], + device=self.device, + ) + if "hof_stats" in self.normalizer: + self.normalizers["target"].set_hof_rescales( + self.normalizer["hof_stats"] + ) + + def 
init_normalizer_from_target_mean(self): + if not hasattr(self,'normalizers'): + self.normalizers = {} + self.normalizers["target"] = Normalizer( + mean=self.normalizer["target_mean"], + std=self.normalizer["target_std"], + device=self.device, + ) + if "hof_stats" in self.normalizer: + self.normalizers["target"].set_hof_rescales( + self.normalizer["hof_stats"] + ) + + def init_normalizer_from_data(self): + if not hasattr(self,'normalizers'): + self.normalizers = {} + self.normalizers["target"] = Normalizer( + tensor=self.datasets["train"].data.y[ + self.datasets["train"].__indices__ + ], + device=self.device, + ) + def load(self): self.load_seed_from_config() self.load_logger() - self.load_datasets() + if self.config["load_datasets"]: + self.load_datasets() + else: + self.init_normalizer_from_target_mean() + # Dataset loader already initializes the normalizer, so if we don't + # load the dataset, then we have to initialize the normalizer from target_mean here self.load_task() self.load_model() self.load_loss() @@ -383,22 +426,9 @@ def load_datasets(self): self.normalizers = {} if self.normalizer.get("normalize_labels", False): if "target_mean" in self.normalizer: - self.normalizers["target"] = Normalizer( - mean=self.normalizer["target_mean"], - std=self.normalizer["target_std"], - device=self.device, - ) - if "hof_stats" in self.normalizer: - self.normalizers["target"].set_hof_rescales( - self.normalizer["hof_stats"] - ) + self.init_normalizer_from_target_mean() else: - self.normalizers["target"] = Normalizer( - tensor=self.datasets["train"].data.y[ - self.datasets["train"].__indices__ - ], - device=self.device, - ) + self.init_normalizer_from_data() @abstractmethod def load_task(self): From ebf196f40c7602c4f3ca6094fb5ca342b2bab651 Mon Sep 17 00:00:00 2001 From: Lena Podina Date: Sun, 1 Dec 2024 14:20:08 -0500 Subject: [PATCH 31/45] removed breakpoint(), fixed FAENet denormalization and checkpoint loading --- ocpmodels/common/gfn.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/ocpmodels/common/gfn.py b/ocpmodels/common/gfn.py index 0e4ba73893..a6f8fe4808 100644 --- a/ocpmodels/common/gfn.py +++ b/ocpmodels/common/gfn.py @@ -11,6 +11,7 @@ from ocpmodels.common.utils import make_trainer_from_dir, resolve from ocpmodels.models.faenet import FAENet from ocpmodels.datasets.data_transforms import get_transforms +from ocpmodels.modules.normalizer import Normalizer class FAENetWrapper(nn.Module): @@ -80,7 +81,6 @@ def preprocess(self, batch: Union[Batch, Data, List[Data], List[Batch]]): and collate them into a Batch. .. code-block:: python - In [7]: %timeit wrapper.preprocess(batch) The slowest run took 4.94 times longer than the fastest. This could mean that an intermediate result is being cached. 
@@ -168,8 +168,8 @@ def forward( if retrieve_hidden: return preds - breakpoint() + # Denormalize predictions preds["energy"] = self.normalizers["target"].denorm( preds["energy"], @@ -278,7 +278,10 @@ def prepare_for_gfn(ckpt_paths: dict, release: str) -> tuple: "cp_data_to_tmpdir": False, }, silent=True, + skip_imports=["qm7x", "gemnet", "spherenet", "painn", "comenet"] ) + trainer.init_normalizer() + trainer.load_checkpoint(ckpt_path) wrapper = FAENetWrapper( faenet=trainer.model, @@ -303,7 +306,9 @@ def prepare_for_gfn(ckpt_paths: dict, release: str) -> tuple: release = "0.0.1" # or ckpt_paths = { - "mila": "/network/scratch/a/alexandre.duval/ocp/catalyst-ckpts/0.0.1/best_checkpoint.pt" + "mila": "/network/scratch/a/alexandre.duval/ocp/catalyst-ckpts/0.0.1/best_checkpoint.pt", + "lpodina": "/home/felixt/shared/checkpoints/best_checkpoint.pt", + "narval": "/home/felixt/shared/checkpoints/best_checkpoint.pt" } release = None wrapper, loaders = prepare_for_gfn(ckpt_paths, release) From 5787d10a4e75cd0c576310b131eeb4d2a082e052 Mon Sep 17 00:00:00 2001 From: Lena Podina Date: Tue, 14 Jan 2025 14:40:48 -0500 Subject: [PATCH 32/45] added train set pickles for graphnorm --- .../embeddings/60_train_graphnorm_mean_0.pickle | Bin 0 -> 1556 bytes .../embeddings/60_train_graphnorm_mean_1.pickle | Bin 0 -> 1556 bytes .../embeddings/60_train_graphnorm_mean_2.pickle | Bin 0 -> 1556 bytes .../embeddings/60_train_graphnorm_mean_3.pickle | Bin 0 -> 1556 bytes .../embeddings/60_train_graphnorm_mean_4.pickle | Bin 0 -> 1556 bytes .../embeddings/60_train_graphnorm_var_0.pickle | Bin 0 -> 1556 bytes .../embeddings/60_train_graphnorm_var_1.pickle | Bin 0 -> 1556 bytes .../embeddings/60_train_graphnorm_var_2.pickle | Bin 0 -> 1556 bytes .../embeddings/60_train_graphnorm_var_3.pickle | Bin 0 -> 1556 bytes .../embeddings/60_train_graphnorm_var_4.pickle | Bin 0 -> 1556 bytes 10 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 ocpmodels/datasets/embeddings/60_train_graphnorm_mean_0.pickle create mode 100644 ocpmodels/datasets/embeddings/60_train_graphnorm_mean_1.pickle create mode 100644 ocpmodels/datasets/embeddings/60_train_graphnorm_mean_2.pickle create mode 100644 ocpmodels/datasets/embeddings/60_train_graphnorm_mean_3.pickle create mode 100644 ocpmodels/datasets/embeddings/60_train_graphnorm_mean_4.pickle create mode 100644 ocpmodels/datasets/embeddings/60_train_graphnorm_var_0.pickle create mode 100644 ocpmodels/datasets/embeddings/60_train_graphnorm_var_1.pickle create mode 100644 ocpmodels/datasets/embeddings/60_train_graphnorm_var_2.pickle create mode 100644 ocpmodels/datasets/embeddings/60_train_graphnorm_var_3.pickle create mode 100644 ocpmodels/datasets/embeddings/60_train_graphnorm_var_4.pickle diff --git a/ocpmodels/datasets/embeddings/60_train_graphnorm_mean_0.pickle b/ocpmodels/datasets/embeddings/60_train_graphnorm_mean_0.pickle new file mode 100644 index 0000000000000000000000000000000000000000..44a4dc491e08f3536dac7382ef4060f2a403643c GIT binary patch literal 1556 zcmZvceO!!L6vt<35~YYp+LG*tWF)_HoIj%8?k*vN_JVFwX_tXwfijp?)l@K&pr2hf9KqPej9Cyh71yX z2-8qjr&lEV%8jgAV=$XW$n~lOqgtbsvns7YrCWnuhfiLk%Dzj<0T%(Iu z$`kdv6nR1ht1|FZ=odBu!A7g?cA9|$;r*U#k1lle-Dnm4#G20=7=gEJI6pE~uVZxz zokpIf(i_w|EhFRsnFD{LQM0T@C0A*c>Ub?O>bP>+Z4}#QRjA(XV6_-^ag)OfoU`TX@J*Fk} zwN^Lc`m6$Mei{b>n@o6QX%1;KodQYsuc(o@(Pz63fN#%WTrqJdj_bLFoNrjyH9IhH zLNN`^UaX)(#cCoi568{!^SS1R@42Rkb}q%%k3Qub1nWY5NTuI>inDa%qj^H=LCRCK z+Pww(?7Ks%oHot0cCG0cnsi|o4sE}R 
zIC&>ZtK3Es4!ffpj&sOKyLa6C2pRQmjHrIap%L}ATUxlthsosp?pY!JR18=D>jj@JBMmy5RkDIWvw#OUeY2cpBe@k7apLrfo@czd5hOSv#Q^* z={9VpMEItq7y?Th>lM3}!}wtbF$omW#~N*5cTEB+6nc|2QA?1%tP-!dA3;91X@DQE zcp+|q2l*_RK@IP>0u|GS>aJ&!y-k5!>)}Y!i(R1RCWr3~R^pM50^oK5(4CWWX>mv- zm43(rn#b%%<0{%nS$q&7lgiNjt6iWp#Nh26<;eA79CR!{N+Q{8sIbaLdvsp#(q|1Q zrCxA`PQk7_hSjenMYve*0P>DnbY)Z)NPdnc_nHjk>z&b1Q8|{_botZHhFCi6oR+-a z6$d4bW%!fot44x!f_VyjYuY=5|c>oR10Kh#DdE|PSEk(6{53yP)%PZem<`l zrpo!+?M=g6f8uBl%Bw3CR zKrq}Ed(`&ed>;m-753n_)hS@kHA~GBfkeo=R-Tzp629b{?0LJ@s1z*kw!FjJFVHJh UdX+Lvtzdc6o-8%9{43vo0UdFiXaE2J literal 0 HcmV?d00001 diff --git a/ocpmodels/datasets/embeddings/60_train_graphnorm_mean_1.pickle b/ocpmodels/datasets/embeddings/60_train_graphnorm_mean_1.pickle new file mode 100644 index 0000000000000000000000000000000000000000..3b95a6be418791f8e010c6d316d488da286b2939 GIT binary patch literal 1556 zcmZvcdsxg_7{_OtQCpIO|d`xyYC5aY;&*T*@dC^=h3od5UqD(LHy-S4_{9*m#98*Iuep$I7KY z>C_3*xOI#|&rvfr_T%x=EV6A>y+yPC6|E`_AL^B6G5?8x%PCr(yTqOwsnMw!wM?y& zCM$G$r8-f|=K#q7?nb3#7?nb*NR%sM6SWrGBsN@=S*KA*^~&`My;hK-6}n5f5j_1$ zwT{vD3$>64xRWxG(OQNICBn}d>cn`hRj4p2iJR#zvHA>uU0KVMeffgTY;JKT>TAt`4T?M}XP22R$0O zpS-AOMCEhWAS)YNywuq_xGLfV`!wh@>UN*WPU^Q*oHu%fYt16Hd*vBOZUv70)?ofvfyius2Te@jgic_UBvTp3ZRb#mYYXqDL6Cp?3=L z8h8tQPm5^!t-WOOGiTbDaS!uGZO3!(H?Vh)enZ2e@=>$B2TlpELYuFY0ZnzoBsLi8 zf-^z>ZVkA;HIPp$Q&3aBKUy-gfjxgIhRqWVBsx(sKCa(K#tDi8)hG-M@oi*LMl|Gk z8&Uh^05X4(hT7Eolc#I>RO|XZD6Tf3c!z1YuzDG!U+F?|ifDW8I5QY!ujig{lRQu(^ zjsq_E&*)=-9ZJc_S9&_|A9r|td%r76vxPLTuM$UeBmAyXPi9Pyq}G;k)UT?H-553ud6`zAu7zgc z|6T-@sFS|k;lnBpJA!(^RJuKVH|p7}BM&$GgZ0WrHur2Pd^8L~mHF)`;m|eltY`?5 zJ{XDTVF8#kyg~NTB=&f80ht_~q4;eyzSHOn&-{AO@P$5D-F=RV=AJ+gtr4E1T!9M$ zkCBEfQ_#+zKM=RB_b8tXhqhUTxTmod?6U*dq#sR)iP}q?H=Khgshrl08IND7Gl-E@ zu?}0ZNlD9kdQ;O2E%n*7s5Js_H!P`!#W5&Lx`4_br;&(Ad)jT@L@#tS(At-cWNSwy zYxn5@4e9Wv)wM3rcI7hJT3rVP2dweJ{6lc2{Z}?4_8?g#GU3Yb3QF&)fGHk}mroi4 zF36P~mwt}DQC@=FJ{j1Gv6IA2LLoL9jFHArUMQb)txO}ALb;Na+n=*rl|sgFZp%4* Y{{=d^LZ^_2C}j+1+8&WchI{4pA7k$Q{bpbbUYb#R(D55cF% zbKwj|O^UB7jnnB(mONLLQJa*e(!chR2G@1Ut6mSLj|%=X-7q$@o}%>L6!B;&y(zTbh{sth2{Fdx70 zESNgBVwzN8Lq1F!dh5zSS+|k+d)H(A7i-~SV>KN+&kQdz0&)LyVQ`~&B>k~24o{nr z!2F^XP^()5UED^o@!3xh`E(V!IeagTJC{N6mKK~JHVdza@UvH_2SaAvezd&Umv&qm zhBd`zX7zR<$OFCsIz5)wjEn?S@eW%3S?qXd4Wm&NqHb}DU1AKyXVXVB9l@PYHf)Rh zl$V3-gvp-?-yVyS-WTD?Upvz^KRhB`1Gb_B-z(ta_6{eQs^H#r7baq8eESLbX0y6T-l5Ip;h#f@U-mBh-WvgQYvo7~#m+`W3rA3=lnprQ)NN#);seQ3 zfkwXABwxPIOh;74pupm})YnF!@@zf0Ms=YKNg%Zq{RSJZ?}Q`wwxXSn-I@I2@wj($ zUo-^HO^(q7K)3w^@zcb~f|~QOwXY*N8rCUaHD(Wf(0G>Cbn39+#xr!mo{D^^0#1$} z1mr4LTj??ZzBE6fADYW$jfsow&nwcgSP_XmN3JG2jwx`?-h1TE$+Hk|n9LjKR=kaNcwNR-UL`QIawPL+T@GGp71joo?uDrR<;4EG9mX68 z#$yZ0v1!sAdt$UFC|=YsF<~OqT3Ck{C>?P4hl|i2JPMfWGCW}Refc2UHe45xjxN2l zL*W;#sAK#^l#vk(*CYr$4Ga~CQej!M6`b?M`0ut=xF9r_der1Yeg6n@yyiSwwr&OZ zc&%hI&4jjl1D*57G>lfCK{*9(_(@d-ML!Qk-3PYfT^kmm!;f-s!qWyat6w55GloFl zkUeDB!+y{!yBY$o%)tMf!2H?qP|2%d3|}JxGd+ z6(06@&0K^|1|Y}Isr1=FJ5;;gh2)z`NI#eYqO}F^Bq);EU!;`pw{~FLjzlQU-)1*A zO@cepakbCghe2uWda~)-SMdCTjx@RMhFR~=5w8eO>Vxu-b>U`+7d;@7?jRQh93t&}_YlCvIT3ZxcP^(Cv^0)WLy=U%x_xtY5kI(E_>gpu; z5SIQXy+OTHs!BI$bw;bjLuJq`N!RL0iae$Sy#OXi#-#p&t1rk~RJJ-BaWPN4qq1W7o*xM@yb zQ^E7yLdFu*KcEq1R7PCSvnt`JJDeX-;Dx__yDH zk3A1j#@~+SJ%52Rx|bo5#)mvD2**#qCL}p?KEC(x6V%%z#@DW7Vxs>R`P5f3wu)8k znC^wR_!sUPDo98Q&)Awf(*(*^50KhL=aPiXB_lk{|@{T|5X^m;?5Z~)>X|8RnJDr z&S&x3*_-fTB!I(v{Auw3k;!vWmPEI<0GzzDpQvg${NnW_kah>qkftff7j(!iQwKkVZh@2vHM=$ZK90Fq zLZ?2z2xeywYWua6dTAXYL8pYbn!oY76bCwbaSgJ2I^$Wjjf}5dptusd6dW6>>5#?w zHcP33+Nc|~Y6ZC7`vHZ;w?WMLSjEfZR`%Q`42ziKG(L5Xd{N@3kT7%yYoBun 
zIYz&r2Yo()A7?&iIzL?l{p|+%-ZLkepxk6y8zsUgGVDy+W`TU1Z7NjTSHqE%YWcX< zTf{DL#n+EG(OlQNs4n;_8T00WJifCYm9+&!Lgol?tZgJcx64pN;8o<6{jn{AW~0=H zVQBoTDtbCK3gi~R3tBk!ISX^fwo~%f% zxk<`)>qvw7zI<-U1>BT&m<2}3Mrgaq%fNEBVa+U_`;la==P{Nj|hQoxrAqqo4 z27N4%gXEZ$9g{Q^wr$seXL+0=#Ig`-k50xTek&sGS-x!HrCgHoPaASoOKc69gV?ND z-jIE>ilO2C;6n3Ws;CKOq91lqpCLJznbyMQ%u~w4Iz}*SCoaGx19H)juI9S@ia$_m zf&-yXPT>au{6FSNV9(0!RB8nD_VIX{Fc#q7TftD=6G(*{%|!O-0GxbXMXvb|!l5g3 zL%(iEq{StVJa}C|cWtxCv(?eGJ>oX4oLG+InhKz)zLc`zB{;sYgM^;(XWvQ0FxbV@ zHupv$-Epswh;Iehm@PZ0m${UTlosKJM@68P?O~Id#qe_HV|=JR7|QnrQu@{)r{N!w z^e1oXqDNI@k>3jZJbVg1K0FT&V0Ty?u>fAIbioPj-f(8#UV84pdKh-LoYc+y2E|tx zu;|DF`I5?K=%94HJg6l_@vLbks9SSb_j5UlNo#zNHQyR%jTA%*dDmKH!)S icB|8Bi#aT`Kl=MH1%9%U> literal 0 HcmV?d00001 diff --git a/ocpmodels/datasets/embeddings/60_train_graphnorm_mean_4.pickle b/ocpmodels/datasets/embeddings/60_train_graphnorm_mean_4.pickle new file mode 100644 index 0000000000000000000000000000000000000000..c860751e7980d7d3e0cfcf6d99ecf1c38c877fdf GIT binary patch literal 1556 zcmZvcX;4#F6vq=tNC{8{MCuCCDtiP56h+>4FLa53LRF%b$PzUg2rqds1-GQE0u{w5 zf~ZskM8JwNP^-B&?zmLPs?$mU#ZeK*R>d8+KJ{zQhjZqh`@8>h?w7woyv5#5@G;69 zIjt^rxlEbGsWp02nTt}VO36}d(v+MkL$B2-vqMaqO+IB#pD;aNl9#DWWh0duZE~7& zsZN`&Ov&L?dY)odivfmdF@v^zUc(A}5=ZhQGj&={o2u0) zvsF61TARTNc_7h=ztN~UPNPz)GSbw^8LUllqLgoPIhiV@UY)Dbv$ktkkxwE&LZDx+ z)p4v?Zj)%spVS!~YbO^aiauy)GnTOuxhN}(pXrk*`2c@fnH8jd{DL%^Uz|($)ReNq z#LszVMrd$&Sa4WSX!wlK&>;4cZfa&0Ya8P!jOAF-TwZQ5EBHTHm@M-Y@)~@w3(T}@ zbr8x8To8FZXB*)jSWYfEKaZ9_!eF3FL3(s}V;B|$>S!x4zWxd3ZXCeDS?jRC+D;xG z>rR}T=1ou~L*QA72O4blrftjjgB{gfSZI9(-}U&zBA08ZH7T7IXTCts-}}O#6OsF#8bDP26!MZ=64m6YfSTBJDrHPHHD1z3K73hB z%)1Z(+b33{J1t)mL;Y=(?S?EcyyGQhnQlkSJ;;*#YerLh#M|%?`aoU<*7$pnhise~>rG7En;>Ky>P`^vqspi{ElGCd-bdF5nj_z7 zCsNg|Q==|4jF0SbZa0TG>Oj3sG0`VjOH0Wlkdh&yufGqXH@r4th42f@*{GjEY(XGp zEs+6ZM;^%I?1B46Ee7Vv)b@u-i0sRv>O|hS#d6mCL=a4GYFvZYhdcqvQU%%CWr4R3 zdDFQ)vmq;_3F9~Qc+@lnDC+f~8=nV6`brz1`t~~XtXlzNysi<32kpSOA|A|-7(zGN zoG1t8Gd!Wc3QjThgNR`%Fccu*{`*>DMGm6o4#dJ;ot~DY-tx$Wr4_iGc?xXp6yU_I zU~FG%&rHeBBlV66*tmqE7YS4G2$y8iag~G`svl{otO=&-o;85-rE{Ru4>{PY2_dR= zmE2XB4tLL6Is3>Gfj4mQ-&xGMWg0LUNF#`wQbde9T%c$^#jHP(*0TN92ebrWHC zleanKUteaE!bmEU(!ic*SMp56Bl=307am?N2Jf!4pw^g8XqM>`WwbP*=|SaKvr-iG zn(=jeRpB<_6%NgU;A|Z#;Q?cO&^dTL@)aMexx4N#y7z12rhVPd7iV1w>CC z>C|?cDox8I-W)*C8qtNz=h-vEj4h63CY#04rF4F+AMi~bcg$pP!2NNDv0L^Mres3_ z^+vOqK6JF2*%Y-IdBa;~*XaQ$Z1yhVa=JSnm=y)voEpvZ6HepRJ9}qQ3kiCOR1UYF zbGFon)S`X&Gnj&Eg$0@lKuzFP`Yo3O7QO7mouwA!D={)vl`G*dFc9W2ofMXF&ar;@3~yf=ji1%Gb_^nRDmX?4k| zsznCfGF6J1V~jjSSK0~$*}_c@+JB0o-*_%LG;?@Bws6K1Yd)`H1yZF8KT>bdaXPh5 ztIA*uMvX3=wc-J#n7`3#I8Mu`nDkUlayl#ApmgAyoLSGPjG7gUk+ojM+DMiB2!V02 z&cLy@3Zc@PKWWlAR-~{|+I-Z|rKhoW3LBG&pD9(^eS`;AW(Dd`UvP-#7w6^-G#ywg z$M9>Sb8$C;>^bh$_S)93Hv?Ud_=%Zkanw(0WEzc%(ts*t4jR`+;L) zDD^5o1F}0gxccThh|t?&`_!+HZHNGm_=jNggEQ(EM?=lgf8oO)hp=O2HuycuBWE6U zLd{ukyi_m-%G_eecj95x^d^PQOf7>atAkMRdI-$-oeDd`?vcC>HEa|8f=ePQ;6`&F zrg}s{SFRe5JkBDu2L{no$~e4lF+fkEFS^GQ*jT(2f*Pgda;y{%s*Xc|vWlMX?|{P> z&Cq+l5PhnG>8lN0B=G%poIGw18ZRCqA);O~eabvMkhGL^gxtXRDJJOgVbE350EL%I zp_y!j%#d~%Benv`&?DG2G#Rqq+QQtxBjn&Hf7rI;DNKwP(D9iOFm&SP`m8yQ+fUrYYrR(yC%uDKrvorMX){V=*TdfT*4XgiCJpRa zi4rk1eO+lG31{UW)s?X)v(%**I zT}Z{(^^}w=Fv@;Ay5erg4em!_ts5Nm+y`d+C|bR80rW_U;qCFC;Y83noN&bpEY*GF zc0nobdF}z$**nq6k_m_9Utrdub}U306zHF$pH7Lxa?YSYp9`Vg!Ps2J!1D(xcu%Tf zNNE!q-`0@K+ZS=E|5|eNr9U#G&J(kx8RvAm(DC0Flg|Dcd=^&#b-va#-@6_*o1c=P zk;lMX#YL)03$bB{2_{rVU{8e*yT6SAMUxPsw*E<40!=utJOSn9JE3|Hi#015D15q& zjG|khWc@g-uXTa5eS7iOGl?)UWes>uIt`MO=~!iPqonQlm6#jyG+e*%pp}74RQ!09@Yy literal 0 HcmV?d00001 diff --git 
a/ocpmodels/datasets/embeddings/60_train_graphnorm_var_1.pickle b/ocpmodels/datasets/embeddings/60_train_graphnorm_var_1.pickle new file mode 100644 index 0000000000000000000000000000000000000000..eddf48c6b17608674f1f58d4acad5127d58d6246 GIT binary patch literal 1556 zcmZuxYfw{H5Kc%iijhJQsZ=repwLPX0UwCo-E)JOW>7F%GrVFT5key&$qfQN>i9-e zj2d-7Ahn2!qoRVMU_~lm>l0MkDGw{sYE`sAaiSKPUiEMH$L^dx-+teoIWymS=e;9_ zao&eT$moqJX=0_B(dtY#%g0KiCfTgjsg;Z-!=yJVv!>hj*!(S{KEOl81*Za$vKS1#G=R=XAWn*({>l#VA1^7qV%INW5w&8qMz{Cyw^O41?1)g$vtcaxyhalQvglGVs1tsX3R2t;7!Ri8+g$o?n1`Ek7el+4V?c2b5&vym*rsN z5hX?P_#SL@EyTLV8F)`B#CYo-oIgcK?^qzlr1Tl^%L;&jT_UXX*)Q?yPXyV{6Cf2F z!tg;eIz8SAG0z*o+2Kl>gDW9!q8;)IRM?RhL=nRRsHf$v>Y3W&&y9jNc9^WoVOiyX*21D##cx=b<;0gELc-0fT6H&Fm`AJ zs+b=nzdxEw`H9jASJ6WC+NYCZ(oW>OW@re^!-PFE$%BD0xV+R2Z!6m2?&t-i&+(EJ z`umctXR>hA)y?2>+l4I6?f_$j7wy_nPS^LG#i~1YYEiwAj=0{3)s5SrzrBPkhc=`P z>p&P=1ttBV^ghu|6178w-UlPg?JGL{yiLfwguzF zq&8f>e>}!|$fpYi6*g39sHqJc zTsk-#Jm>NxJD#bb@0KT=X$r$Rxw}b&x0Fbh=0iZ1fPQMX;K8yS_OC|JZ)dI}hc*nr zrOSL0e2U;ud@;6N9Tqw6$pUgZ;WISnKB4w*I)ZU;>)~AOGAh!OB8^jgpy=Wpvc~m3 zyzW(F`r~u(slbX&E3J4XW(fvG9)xH0#n8XS8|}AtK<{xm^-!(A(d}~J)Stw{E;+i$ z!?3S32v;BQfzHGSa5*Xz-ZEAs(2AR@t8w>`1CAw^B2C@KfQ~LAw`P&T>?-*7X*fP_ z`Ui<|7`DXx1%<6v+Tk7v{zfj^{i~_OeJQl|#sreEYaVuP*pD8yMpALF4~1zVa729u zCC^J?>BXm{r@V#g806wc(H-m@?FoSkIWQ?ZkGk-Q1g|*UF~+qPadTLiuU!AN3F?tVEPT*WyvRvQ@x?C zW+yCZn@m@neuelf9lB1BpqGso8X&@zPxHk?Y@XLS2x1(AEmI|PfU0k zS2`?n9Xak%09kj0(9{baM*GV0s0V*#&f+ zbrY!X4pE_QaVVb~jH_R)f`_>ph$u?N3hy}RY&F4)l5VoXTZ)_PTj_Z@5%@A`3?5e9 z$1i0^&EvWEzQgf=J^8kV7Yj}HL-e$lb69fwA+HqSz0qsnLRf$MW)|zjS3AiSko}!fNRT)aX)*?&G_mpuX%=8Hw z9jmpJn9D4Vr@X8#BCqP<*soCX(zE=JmQ z@)p|D*@`Ue3B&)au)$tstvE??nAzJso8;GACIu005VFS}rSf(my7V%-GEjwzE+o)m zpGuszvlR`7m5`P|cf9!76#V+2N^Cv4nvNafLq_jzq%F%(LUSxsG zupA_qbqk-F&qKwVtg%z;2|7?%f<)^X=qSx0-IdEhXx)t?gUzXf^>6Xxycu{fsUM{K zPLR88W8tqO2Vt+jg50moAn9EZsIyLpML)M89A-yIS~Qlg*?_NXb0p82Y>{YY4SKN9 z7F0eXg0ZX>!lN>Y%kc`B^46X_e?Nsz?6ie}pozgnlj^CE%k$BjDK{af&`oTZI#qn` z=}WY_LxY;OtQ9}I@DeG)GD+V_XQ)`U6REA{p+;ZAjCygK)L&0zq8u-hHE&)peaG6V zF4rM6elwzo<`f$Ku>?xhHdsB)i25}C)X2LIunRYfi|HUHqSuCsw`gMq*%*NL0hCGf zgh{Kl!^bY4fwHS#qqfFEs7zRnLTd`~?SF;zfw@*h6h0NQmo9(<6C!a}@B(Cb{Oa1lA;6XF@O(OwS_Ux)>X9q3liSm?@VA(9pfN<%O%qTJEnUDF|d_i$4F$b_0#jK}4_cHsL< zf1-UB?!fD9&XCAI8N6|5KdQ*)gWKv+Xd-Ea{)sd2NZ&<>?evA{9t(0j>3ilweI9aK zwwS&Wv7L$H*~99@qv%daB^X7U#OGRnp`5bQs6Dqln7Dv<$o5noT;FR8A=5XaC2#wQ zckg~^>%EUdJiK9GRUH)iOu*YdCd1J&&Coy13N!1zLY00_Bu>7POmLCGnI}@neD{&8 zEtEn>{Y}Vq@rCTC`P7XCQM6xN4&}DD0cXW1nb|oW;;<8*)I-5`crIzcL4vDjSLh9T zT#Fv%UtLRfCA^~4V@~0KicR9Yyea_a>3B@jc`KX0^3mbKo6n!NSR_oFC~=*!RwA1=x|KL6Fh%|wB}BjEHuC|YgdS_I7E%UxdW-r zS5tGGXM%9%StJ}b7i~T#fLi+kCgp7by|nH)$vqi}`>rX7VPzFlka7gs&gzEFZWUSb z;vTNr&<~kU%yGuYV|3u=e8jJp;M)18Q3=0}+`eRv2izJ_r1M(bsdT;3md)%=5ueww0)M3&FS*KS;0$Vm zUX`XbnskO#R>%XCF8qyN$8mbCN}HF^2sd1Y3h{_hD+a(;1c-sh$h zD^$Ajn8>h@2uX+}I4m?`T6hHenOD8a%!+3F2xB)0-fG;e~XoQ*s)V>A=Lvjlr@ID$vcv4{0>V`1K`Wzl6ZLD=*t95XxJ%~l_S$OBYRQSpqgq?rZQ8NsM zxW#yj9G={P(j~6A-*+KVTa-%=eA5gAKaR#hbT@eC;RtUJuafPG&Bhz*TM3zHK2y|t z6y;2+#amUw;L$an*iBYUpx?g2{dfJCs74OczKdi-?v+!O4o6{?kA(T9rU}%a1Slyx zhdAk}#)U^xu(O|tIzsfw?I*qmqr{oS_K9zp<9V^9ra2p>z05*y-D2gmM>49eYaCP zn=*ji>^@5KOn{^Odcby{6#U!WY23U8!EA;sQkSeDQV23B8%MOU+>^X9r@UwcR7 zU7iO+$z%BC^l;Sfc>qY|{z#e3DvD2KwY&1KKCMnTEseNN89~?=SBK3Az4Q*!u@8_80@Vy(q##Wfo#~ z!BkLsxC3}4Jq2w+2l0!&b>MN~FyfT>618N@O47!4kbH3NDUiF_)31VxP*>bz@YAhZ z_}Q3u(uXM?%wTsdUiQ8M2SN zm2=h|wqJV<7Bg*13sUqvSO^KXVJ;QP4N2gOd@hu zmh?`G6E4gyK?>Il*dMVTODo0bPe(m`Qr-@r=6!*b-`%k=+2s@>O!dUhZ2{0L zV*(P{UZ%sEf@H>_ala)^j mzT{b)dArqX)g14(yu&-k8Z}y@RwL7?Io`Ab;;bA$CH@z&RB2@Z literal 0 HcmV?d00001 diff 
--git a/ocpmodels/datasets/embeddings/60_train_graphnorm_var_4.pickle b/ocpmodels/datasets/embeddings/60_train_graphnorm_var_4.pickle new file mode 100644 index 0000000000000000000000000000000000000000..d3a618823e058a24399173316f54191cb0d7a4c8 GIT binary patch literal 1556 zcmZvcX;c$e6vqQeP{SgMC@7pFSi!0Y0P=a1qFeK zhykq|C>FsbLKP4N)LOJhEkdz)T&V?=Mf6B5Dpu`SzxIB(_q_Xl?|<+6@=xOzj zsFZ6A`t{1#O(wOGrO?fM4kyhvN1*-ZpyZv`+VVy2{%N*UFPF_*7>=L9nH`yE(3|ve zdaW{9Z7^!|I>v?t6prkTR%0@0)k?KarHRoowmTF8wrSdws8$*^DQY9b-NNwv6zm9& zF+p!IF?_;S!DUYxor$p{cnaQU4ZSX&8A0%plGvGkiV>gTFDo;gxc|N&kg$uJmVRjp z7#oEX%ghfBoEH)p5)d3Xe{OIv^F=onHv3-2FAM}y2X-D6b?~C0-B%xp$Q$)!P&q=2Z>@xlZr!k7 zIt7a`=)mb$HZW3}MCB3!=(>DlxZ24;)w z1;}_Xiug{zC0myjAggLEK-a02)kjT&vN|V;Y%m&D4jaMqw?=B1yH2)%*h?mN=Y`ho z$%4_gB5)_l4R(#mL2X|a zS`cd@Vcl0$L9VD4Ra$+`AhMDai>UIHJBJ;bY4IdNPxyZlpme{|%L%5Df%7&7{v_ zl-#iVdx>=FEp&884w!zvhAz4WC{Ncyg5!FQBu>{O%F9KVwbz9DrKo>gOZg6%RZ;?MW`$7vvTaz# zgmkc(zlFNk>I`?a_=aC<-3^O`Lh^b}Kf=G+DDimHg2b91$_hMBpoaxs_>BRjIF_FZ zN>&6yNB3p&)X7fbbK^Tm;p@Y&Vxu2etz8Yq#s2~FQd(eom5PoI5K;Tz8ewbZLmXE* zi??SMlZO_(fTzYBksyYLNCE#xdXzK@Jy_l?P5MuF)A4E`O7sq|zBe8Z9%v#3bbeUgv7_YATODDgvlXnVwI}vAOr$-L zW~w*&PZ%^vk;RoL`b#bkRlR6Oo}C^Et`4VxlhKW#S4wx|Z*J$qfKdTpa+sC8bfFNg zD0@L3*y01FmPwNOb@{Y7>=+nGUx!TZ6Qf_B(ZZ_O8d?8|WoWqeCGqK?54f{vJ{9k< z7pNLHqlph9#JPih@CQ!CH zQUrb5JS*I~J{1i2T%mk9@l=(t8;d&613r(QB8Q&N#u{=Qz>rmjW;D;1a&K3mmhF}$ y7J@_Au&$M9VH3fYY_mOUw_0_aiFI4n;qB!HmD-?I;hH!TYueM6SWN6I|9=7ZGH$;B literal 0 HcmV?d00001 From 3d2b6b601cf02a5577c6ea6b313e94be7352305a Mon Sep 17 00:00:00 2001 From: Lena Podina Date: Tue, 14 Jan 2025 14:43:49 -0500 Subject: [PATCH 33/45] load_datasets=True for testing --- configs/models/tasks/is2re.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/models/tasks/is2re.yaml b/configs/models/tasks/is2re.yaml index e237720aa6..c7b3def5b1 100644 --- a/configs/models/tasks/is2re.yaml +++ b/configs/models/tasks/is2re.yaml @@ -18,7 +18,7 @@ default: mode: train adsorbates: all # {"*O", "*OH", "*OH2", "*H"} adsorbates_ref_dir: /network/scratch/s/schmidtv/ocp/datasets/ocp/per_ads - load_datasets: False + load_datasets: True dataset: default_val: val_id train: From fca2d528071b81bf4a862207ee2a5182439f60da Mon Sep 17 00:00:00 2001 From: Lena Podina Date: Tue, 14 Jan 2025 14:46:49 -0500 Subject: [PATCH 34/45] more comprehensive testing in the main of gfn.py --- ocpmodels/common/gfn.py | 59 ++++++++++++++++++++++++++++++++++++++--- 1 file changed, 56 insertions(+), 3 deletions(-) diff --git a/ocpmodels/common/gfn.py b/ocpmodels/common/gfn.py index a6f8fe4808..4e99da33fa 100644 --- a/ocpmodels/common/gfn.py +++ b/ocpmodels/common/gfn.py @@ -13,6 +13,9 @@ from ocpmodels.datasets.data_transforms import get_transforms from ocpmodels.modules.normalizer import Normalizer +import numpy as np +import matplotlib.pyplot as plt +import pickle class FAENetWrapper(nn.Module): def __init__( @@ -295,6 +298,13 @@ def prepare_for_gfn(ckpt_paths: dict, release: str) -> tuple: return wrapper, loaders +def to_data_list(batch): + '''Better Batch.to_data_list() because it preserves the neighbors which sometimes + get dropped when using Batch.to_data_list() only''' + batch_to_list = batch.to_data_list() + for idx,item in enumerate(batch_to_list): + item.neighbors = batch.neighbors[idx] + return batch_to_list if __name__ == "__main__": # for instance in ipython: @@ -312,6 +322,49 @@ def prepare_for_gfn(ckpt_paths: dict, release: str) -> tuple: } release = None wrapper, loaders = prepare_for_gfn(ckpt_paths, release) - data_gen = iter(loaders["train"]) - 
batch = next(data_gen) - preds = wrapper(batch) + + data_gen_ood_cat = iter(loaders["val_ood_cat"]) + data_gen_id = iter(loaders["val_id"]) + + print("Testing val ood cat 10 batches...") + batch_i = 0 + while batch := next(data_gen_ood_cat): + print(f"{batch_i=}") + if batch_i < 10: + preds_1 = wrapper(deepcopy(batch)).detach().cpu().numpy() + true_1 = np.array([b.y_relaxed for b in batch]).flatten() + print(f"Test batch {batch_i} val_ood_cat mae: {np.mean(np.abs(preds_1 - true_1))=}") + else: + break + batch_i += 1 + + print("Testing val id 10 batches...") + batch_i = 0 + while batch := next(data_gen_id): + print(f"{batch_i=}") + if batch_i < 10: + preds_1 = wrapper(deepcopy(batch)).detach().cpu().numpy() + true_1 = np.array([b.y_relaxed for b in batch]).flatten() + print(f"Test batch {batch_i} val_id mae: {np.mean(np.abs(preds_1 - true_1))=}") + else: + break + batch_i += 1 + + print("Testing whether the same samples within two different batches in fact give the same outputs...") + train_set_iterator = iter(loaders["train"]) + train_batch_0 = next(train_set_iterator) + train_batch_0 = train_batch_0[0] + first_5 = to_data_list(train_batch_0)[:5] + first_10 = to_data_list(train_batch_0)[:10] + batch_first_5 = [Batch.from_data_list(first_5)] + batch_first_10 = [Batch.from_data_list(first_10)] + + preds_batch_first_5 = wrapper(deepcopy(batch_first_5)).detach().cpu().numpy() + true_batch_first_5 = np.array([b.y_relaxed for b in batch_first_5]).flatten() + print(f"Test batch preds first 5: {preds_batch_first_5=}") + print(f"Test batch true first 5: {true_batch_first_5=}") + + preds_batch_first_10 = wrapper(deepcopy(batch_first_10)).detach().cpu().numpy() + true_batch_first_10 = np.array([b.y_relaxed for b in batch_first_10]).flatten() + print(f"Test batch preds first 10: {preds_batch_first_10=}") + print(f"Test batch true first 10: {true_batch_first_10=}") From 872763387f726ceaed34f10666db392e8df15cec Mon Sep 17 00:00:00 2001 From: Lena Podina Date: Tue, 14 Jan 2025 14:48:09 -0500 Subject: [PATCH 35/45] if the mode is set to inference, set training to false --- ocpmodels/models/base_model.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ocpmodels/models/base_model.py b/ocpmodels/models/base_model.py index e2df0e7375..e301a36f4d 100644 --- a/ocpmodels/models/base_model.py +++ b/ocpmodels/models/base_model.py @@ -84,6 +84,8 @@ def forward(self, data, mode="train", regress_forces=None, q=None): data["catalyst"].pos.requires_grad_(True) else: data.pos.requires_grad_(True) + if mode == "inference": + self.training = False # predict energy preds = self.energy_forward(data, q=q) From afdf7824c39097fb13b483b5e6bac1712d0253db Mon Sep 17 00:00:00 2001 From: Lena Podina Date: Tue, 14 Jan 2025 15:30:56 -0500 Subject: [PATCH 36/45] torch_scatter bug fixed --- ocpmodels/models/depfaenet.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/ocpmodels/models/depfaenet.py b/ocpmodels/models/depfaenet.py index 4e83dbc0ec..df90d16da8 100644 --- a/ocpmodels/models/depfaenet.py +++ b/ocpmodels/models/depfaenet.py @@ -61,9 +61,8 @@ def forward(self, h, edge_index, edge_weight, batch, alpha, data): # We pool separately and then we concatenate. 
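+        # Why the indexing fix below matters -- a sketch with illustrative values:
+        # with batch = [0, 0, 1, 1] and ads = [True, False, True, False],
+        # the old `batch * ads` gave indices [0, 0, 1, 0], so catalyst atoms from
+        # every graph were summed into graph 0's pooled rep; `h[ads, :], batch[ads]`
+        # keeps rows and graph ids aligned instead.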
ads = self.current_tags == 2 cat = ~ads - - ads_out = scatter(h, batch * ads, dim=0, reduce="add") - cat_out = scatter(h, batch * cat, dim=0, reduce="add") + ads_out = scatter(h[ads,:], batch[ads], dim=0, reduce="add") + cat_out = scatter(h[cat,:], batch[cat], dim=0, reduce="add") if self.disconnected_mlp: ads_out = self.ads_lin(ads_out) From d1c5835a0ce0ee575900f3cf83d826cbf75f5d99 Mon Sep 17 00:00:00 2001 From: Lena Podina Date: Tue, 14 Jan 2025 15:59:55 -0500 Subject: [PATCH 37/45] added logic for graphnorm_inference, if statements are untested --- ocpmodels/models/faenet.py | 32 +++++-- ocpmodels/modules/graphnorm_inference.py | 109 +++++++++++++++++++++++ 2 files changed, 133 insertions(+), 8 deletions(-) create mode 100644 ocpmodels/modules/graphnorm_inference.py diff --git a/ocpmodels/models/faenet.py b/ocpmodels/models/faenet.py index 78b9980cd7..af24e52172 100644 --- a/ocpmodels/models/faenet.py +++ b/ocpmodels/models/faenet.py @@ -18,6 +18,7 @@ from ocpmodels.models.force_decoder import ForceDecoder from ocpmodels.models.utils.activations import swish from ocpmodels.modules.phys_embeddings import PhysEmbedding +from ocpmodels.modules.graphnorm_inference import GraphNormInference class GaussianSmearing(nn.Module): @@ -191,9 +192,14 @@ def __init__( self.graph_norm = graph_norm self.dropout_lin = float(dropout_lin) if graph_norm: - self.graph_norm = GraphNorm( - hidden_channels if "updown" not in self.mp_type else num_filters - ) + if not self.training: + self.graph_norm = GraphNormInference( + hidden_channels if "updown" not in self.mp_type else num_filters + ) + else: + self.graph_norm = GraphNorm( + hidden_channels if "updown" not in self.mp_type else num_filters + ) if self.mp_type == "simple": self.lin_h = nn.Linear(hidden_channels, hidden_channels) @@ -240,7 +246,7 @@ def reset_parameters(self): nn.init.xavier_uniform_(self.lin_h.weight) self.lin_h.bias.data.fill_(0) - def forward(self, h, edge_index, e): + def forward(self, h, edge_index, e,ib): # Define edge embedding if self.dropout_lin > 0: @@ -264,7 +270,10 @@ def forward(self, h, edge_index, e): h = self.act(self.lin_down(h)) # downscale node rep. 
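             # (Sketch, assuming the usual PyG MessagePassing contract: propagate
             # gathers x_j for each edge, weights it with W=e inside message(),
             # and sums the resulting messages back onto each target node.)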
h = self.propagate(edge_index, x=h, W=e) # propagate if self.graph_norm: - h = self.act(self.graph_norm(h)) + if self.training: + h = self.act(self.graph_norm(h)) + else: + h = self.act(self.graph_norm(h, ib)) h = F.dropout( h, p=self.dropout_lin, training=self.training or self.deup_inference ) @@ -279,7 +288,10 @@ def forward(self, h, edge_index, e): e = self.lin_geom(e) h = self.propagate(edge_index, x=h, W=e) # propagate if self.graph_norm: - h = self.act(self.graph_norm(h)) + if self.training: + h = self.act(self.graph_norm(h)) + else: + h = self.act(self.graph_norm(h, ib)) h = torch.cat((h, chi), dim=1) h = F.dropout( h, p=self.dropout_lin, training=self.training or self.deup_inference @@ -289,7 +301,10 @@ def forward(self, h, edge_index, e): elif self.mp_type in {"base", "simple"}: h = self.propagate(edge_index, x=h, W=e) # propagate if self.graph_norm: - h = self.act(self.graph_norm(h)) + if self.training: + h = self.act(self.graph_norm(h)) + else: + h = self.act(self.graph_norm(h, ib)) h = F.dropout( h, p=self.dropout_lin, training=self.training or self.deup_inference ) @@ -739,7 +754,8 @@ def energy_forward(self, data, q=None): self.first_trainable_layer.split("_")[1] ): q = h.clone().detach() - h = h + interaction(h, edge_index, e) + h = h + interaction(h, edge_index, e, ib) + # Atom skip-co if self.skip_co == "concat_atom": diff --git a/ocpmodels/modules/graphnorm_inference.py b/ocpmodels/modules/graphnorm_inference.py new file mode 100644 index 0000000000..7ddb4b32cf --- /dev/null +++ b/ocpmodels/modules/graphnorm_inference.py @@ -0,0 +1,109 @@ +from typing import Optional + +import torch +from torch import Tensor + +from torch_geometric.nn.inits import ones, zeros +from torch_geometric.typing import OptTensor +from torch_geometric.utils import scatter + +import pickle +import matplotlib.pyplot as plt +import numpy as np + + +class GraphNormInference(torch.nn.Module): + r"""Applies graph normalization over individual graphs as described in the + `"GraphNorm: A Principled Approach to Accelerating Graph Neural Network + Training" `_ paper. + + .. math:: + \mathbf{x}^{\prime}_i = \frac{\mathbf{x} - \alpha \odot + \textrm{E}[\mathbf{x}]} + {\sqrt{\textrm{Var}[\mathbf{x} - \alpha \odot \textrm{E}[\mathbf{x}]] + + \epsilon}} \odot \gamma + \beta + + where :math:`\alpha` denotes parameters that learn how much information + to keep in the mean. + + Args: + in_channels (int): Size of each input sample. + eps (float, optional): A value added to the denominator for numerical + stability. 
(default: :obj:`1e-5`) + """ + def __init__(self, in_channels: int, eps: float = 1e-5): + super().__init__() + + self.in_channels = in_channels + self.eps = eps + + self.weight = torch.nn.Parameter(torch.empty(in_channels)) + self.bias = torch.nn.Parameter(torch.empty(in_channels)) + self.mean_scale = torch.nn.Parameter(torch.empty(in_channels)) + + self.reset_parameters() + + self.training_means = {} + self.training_vars = {} + for interaction_block_idx in range(5): + with open(f'ocpmodels/datasets/embeddings/60_train_graphnorm_mean_{interaction_block_idx}.pickle', 'rb') as handle: + self.training_means[interaction_block_idx] = pickle.load(handle) + with open(f'ocpmodels/datasets/embeddings/60_train_graphnorm_var_{interaction_block_idx}.pickle', 'rb') as handle: + self.training_vars[interaction_block_idx] = pickle.load(handle) + + def reset_parameters(self): + r"""Resets all learnable parameters of the module.""" + ones(self.weight) + zeros(self.bias) + ones(self.mean_scale) + + + def forward(self, x: Tensor, idx: Optional[int]=0, batch: OptTensor = None, + batch_size: Optional[int] = None) -> Tensor: + r"""Forward pass. + + Args: + x (torch.Tensor): The source tensor. + batch (torch.Tensor, optional): The batch vector + :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns + each element to a specific example. (default: :obj:`None`) + batch_size (int, optional): The number of examples :math:`B`. + Automatically calculated if not given. (default: :obj:`None`) + """ + # print("[graphnorm_inference] in forward") + if batch is None: + batch = x.new_zeros(x.size(0), dtype=torch.long) + batch_size = 1 + + if batch_size is None: + batch_size = int(batch.max()) + 1 + + # f = open("ocpmodels/datasets/embeddings/graphnorm_info.txt", "r") + # s = int(f.read()) + # print(f"{s=}") + mean = self.training_means[idx] + var = self.training_vars[idx] + + # mean_from_batch = scatter(x, batch, 0, batch_size, reduce='mean') + # with open(f'ocpmodels/datasets/embeddings/train_batch_0{s+1}_graphnorm_mean_{idx}.pickle', 'wb') as handle: + # pickle.dump(mean_from_batch, handle, protocol=pickle.HIGHEST_PROTOCOL) + # print(f"mean_per_batch: {torch.linalg.norm(mean_per_batch.flatten(),ord=torch.inf)=}") + # print(f"mean_from_pickle: {torch.linalg.norm(mean.flatten(),ord=torch.inf)=}") + out = x - mean.index_select(0, batch) * self.mean_scale + + # var_from_batch = scatter(out.pow(2), batch, 0, batch_size, reduce='mean') + # with open(f'ocpmodels/datasets/embeddings/train_batch_0{s+1}_graphnorm_var_{idx}.pickle', 'wb') as handle: + # pickle.dump(var_from_batch, handle, protocol=pickle.HIGHEST_PROTOCOL) + # print(f"var_per_batch: {torch.linalg.norm(var_per_batch.flatten(),ord=torch.inf)=}") + # print(f"var_from_pickle: {torch.linalg.norm(var.flatten(),ord=torch.inf)=}") + std = (var + self.eps).sqrt().index_select(0, batch) + + # if idx == 4: + # s += 1 + # with open("ocpmodels/datasets/embeddings/graphnorm_info.txt", 'w') as f: + # f.write('%d' % s) + return self.weight * out / std + self.bias + + + def __repr__(self): + return f'{self.__class__.__name__}({self.in_channels})' \ No newline at end of file From 38405fd142d65b9c5f86bd3771db2a58f1741c98 Mon Sep 17 00:00:00 2001 From: Lena Podina Date: Mon, 31 Mar 2025 12:06:27 -0400 Subject: [PATCH 38/45] added some testing code --- ocpmodels/common/gfn.py | 85 ++++++++++++++++++++++++++--------------- 1 file changed, 54 insertions(+), 31 deletions(-) diff --git a/ocpmodels/common/gfn.py b/ocpmodels/common/gfn.py index 4e99da33fa..9c7bec28c4 100644 --- 
a/ocpmodels/common/gfn.py +++ b/ocpmodels/common/gfn.py @@ -172,13 +172,10 @@ def forward( if retrieve_hidden: return preds - # Denormalize predictions preds["energy"] = self.normalizers["target"].denorm( preds["energy"], ) - # preds["energy"] = preds["energy"].to(torch.float16) - return preds["energy"] def freeze(self): @@ -322,49 +319,75 @@ def to_data_list(batch): } release = None wrapper, loaders = prepare_for_gfn(ckpt_paths, release) + wrapper.eval() data_gen_ood_cat = iter(loaders["val_ood_cat"]) + data_gen_ood_both = iter(loaders["val_ood_both"]) + data_gen_ood_ads = iter(loaders["val_ood_ads"]) data_gen_id = iter(loaders["val_id"]) + train_set_iterator = iter(loaders["train"]) - print("Testing val ood cat 10 batches...") - batch_i = 0 - while batch := next(data_gen_ood_cat): - print(f"{batch_i=}") - if batch_i < 10: - preds_1 = wrapper(deepcopy(batch)).detach().cpu().numpy() - true_1 = np.array([b.y_relaxed for b in batch]).flatten() - print(f"Test batch {batch_i} val_ood_cat mae: {np.mean(np.abs(preds_1 - true_1))=}") - else: - break - batch_i += 1 - - print("Testing val id 10 batches...") + print("Testing batches...") batch_i = 0 - while batch := next(data_gen_id): + while batch := next(data_gen_ood_both): print(f"{batch_i=}") - if batch_i < 10: - preds_1 = wrapper(deepcopy(batch)).detach().cpu().numpy() - true_1 = np.array([b.y_relaxed for b in batch]).flatten() - print(f"Test batch {batch_i} val_id mae: {np.mean(np.abs(preds_1 - true_1))=}") - else: - break + preds_1 = wrapper(deepcopy(batch)).detach().cpu().numpy() + true_1 = np.array([b.y_relaxed for b in batch]).flatten() + print(f"Test batch {batch_i} mae: {np.mean(np.abs(preds_1 - true_1))=}") batch_i += 1 + exit(1) + + # print("Testing val id 10 batches...") + # batch_i = 0 + # while batch := next(data_gen_ood_ads): + # print(f"{batch_i=}") + # if batch_i < 5: + # preds_1 = wrapper(deepcopy(batch)).detach().cpu().numpy().flatten() + # true_1 = np.array([b.y_relaxed for b in batch]).flatten() + # plt.plot(true_1,preds_1,'o',label=f'Batch {batch_i}') + # plt.plot(true_1,true_1,c='r') + # plt.legend() + # print(f"Test batch {batch_i} val_id mae: {np.mean(np.abs(preds_1 - true_1))=}") + # else: + # break + # batch_i += 1 + # plt.savefig('val_ood_ads_depfaenet') + # exit(1) print("Testing whether the same samples within two different batches in fact give the same outputs...") - train_set_iterator = iter(loaders["train"]) - train_batch_0 = next(train_set_iterator) + + train_batch_0 = next(data_gen_ood_ads) + # with open('train_batch_0.pickle', 'wb') as handle: + # pickle.dump(train_batch_0, handle, protocol=pickle.HIGHEST_PROTOCOL) train_batch_0 = train_batch_0[0] - first_5 = to_data_list(train_batch_0)[:5] - first_10 = to_data_list(train_batch_0)[:10] + first_5 = to_data_list(train_batch_0)[:64] + first_10 = to_data_list(train_batch_0)[:128] batch_first_5 = [Batch.from_data_list(first_5)] + print(f"{batch_first_5=}") batch_first_10 = [Batch.from_data_list(first_10)] + print(f"{batch_first_10=}") + # with open('batch_first_10.pickle', 'wb') as handle: + # pickle.dump(first_10, handle, protocol=pickle.HIGHEST_PROTOCOL) + # exit(1) + print("First 5") preds_batch_first_5 = wrapper(deepcopy(batch_first_5)).detach().cpu().numpy() true_batch_first_5 = np.array([b.y_relaxed for b in batch_first_5]).flatten() - print(f"Test batch preds first 5: {preds_batch_first_5=}") - print(f"Test batch true first 5: {true_batch_first_5=}") + # print(f"Test batch preds first 1: {preds_batch_first_5=}") + # print(f"Test batch true first 1: 
{true_batch_first_5=}") + print(f"Test mae: {np.mean(np.abs(preds_batch_first_5 - true_batch_first_5))=}") + print("First 10") preds_batch_first_10 = wrapper(deepcopy(batch_first_10)).detach().cpu().numpy() true_batch_first_10 = np.array([b.y_relaxed for b in batch_first_10]).flatten() - print(f"Test batch preds first 10: {preds_batch_first_10=}") - print(f"Test batch true first 10: {true_batch_first_10=}") + # print(f"Test batch preds first 2: {preds_batch_first_10=}") + # print(f"Test batch true first 2: {true_batch_first_10=}") + print(f"Test mae: {np.mean(np.abs(preds_batch_first_10 - true_batch_first_10))=}") + + + # print("First 256") + # print(f"{train_batch_0=}") + # preds_train_batch_0 = wrapper(deepcopy(train_batch_0)).detach().cpu().numpy() + # true_train_batch_0 = np.array([b.y_relaxed for b in train_batch_0]).flatten() + # print(f"Test batch preds first 2: {preds_train_batch_0=}") + # print(f"Test batch true first 2: {true_train_batch_0=}") From 31d3e046963ddae194dee10187684a0786a1d5c0 Mon Sep 17 00:00:00 2001 From: Lena Podina Date: Mon, 31 Mar 2025 12:10:37 -0400 Subject: [PATCH 39/45] removed hacky graphnorm fixes --- ocpmodels/models/faenet.py | 33 ++++++++------------------------- 1 file changed, 8 insertions(+), 25 deletions(-) diff --git a/ocpmodels/models/faenet.py b/ocpmodels/models/faenet.py index af24e52172..e4737e1df8 100644 --- a/ocpmodels/models/faenet.py +++ b/ocpmodels/models/faenet.py @@ -18,7 +18,6 @@ from ocpmodels.models.force_decoder import ForceDecoder from ocpmodels.models.utils.activations import swish from ocpmodels.modules.phys_embeddings import PhysEmbedding -from ocpmodels.modules.graphnorm_inference import GraphNormInference class GaussianSmearing(nn.Module): @@ -192,14 +191,9 @@ def __init__( self.graph_norm = graph_norm self.dropout_lin = float(dropout_lin) if graph_norm: - if not self.training: - self.graph_norm = GraphNormInference( - hidden_channels if "updown" not in self.mp_type else num_filters - ) - else: - self.graph_norm = GraphNorm( - hidden_channels if "updown" not in self.mp_type else num_filters - ) + self.graph_norm = GraphNorm( + hidden_channels if "updown" not in self.mp_type else num_filters + ) if self.mp_type == "simple": self.lin_h = nn.Linear(hidden_channels, hidden_channels) @@ -246,7 +240,7 @@ def reset_parameters(self): nn.init.xavier_uniform_(self.lin_h.weight) self.lin_h.bias.data.fill_(0) - def forward(self, h, edge_index, e,ib): + def forward(self, h, edge_index, e,batch=None): # Define edge embedding if self.dropout_lin > 0: @@ -270,10 +264,7 @@ def forward(self, h, edge_index, e,ib): h = self.act(self.lin_down(h)) # downscale node rep. 
h = self.propagate(edge_index, x=h, W=e) # propagate if self.graph_norm: - if self.training: - h = self.act(self.graph_norm(h)) - else: - h = self.act(self.graph_norm(h, ib)) + h = self.act(self.graph_norm(h,batch=batch)) h = F.dropout( h, p=self.dropout_lin, training=self.training or self.deup_inference ) @@ -288,10 +279,7 @@ def forward(self, h, edge_index, e,ib): e = self.lin_geom(e) h = self.propagate(edge_index, x=h, W=e) # propagate if self.graph_norm: - if self.training: - h = self.act(self.graph_norm(h)) - else: - h = self.act(self.graph_norm(h, ib)) + h = self.act(self.graph_norm(h,batch=batch)) h = torch.cat((h, chi), dim=1) h = F.dropout( h, p=self.dropout_lin, training=self.training or self.deup_inference @@ -301,10 +289,7 @@ def forward(self, h, edge_index, e,ib): elif self.mp_type in {"base", "simple"}: h = self.propagate(edge_index, x=h, W=e) # propagate if self.graph_norm: - if self.training: - h = self.act(self.graph_norm(h)) - else: - h = self.act(self.graph_norm(h, ib)) + h = self.act(self.graph_norm(h,batch=batch)) h = F.dropout( h, p=self.dropout_lin, training=self.training or self.deup_inference ) @@ -729,7 +714,6 @@ def energy_forward(self, data, q=None): if not hasattr(data, "deup_q"): # Embedding block h, e = self.embed_block(z, rel_pos, edge_attr, data.tags) - if "inter" and "0" in self.first_trainable_layer: q = h.clone().detach() @@ -738,7 +722,6 @@ def energy_forward(self, data, q=None): alpha = self.w_lin(h) else: alpha = None - # Interaction blocks energy_skip_co = [] for ib, interaction in enumerate(self.interaction_blocks): @@ -754,7 +737,7 @@ def energy_forward(self, data, q=None): self.first_trainable_layer.split("_")[1] ): q = h.clone().detach() - h = h + interaction(h, edge_index, e, ib) + h = h + interaction(h, edge_index, e, batch) # Atom skip-co From 7098c045674e983a8f0820971ca3b42cb87d443e Mon Sep 17 00:00:00 2001 From: Lena Podina Date: Wed, 9 Apr 2025 14:26:07 -0400 Subject: [PATCH 40/45] changed some normalization things, added oc22 configs/paths --- configs/exps/deup/gnn/pretrain-depfaenet.yaml | 27 ++++++++++++ configs/models/depfaenet.yaml | 42 +++++++++++++++++++ configs/models/tasks/is2re.yaml | 16 +++---- ocpmodels/common/gfn.py | 10 ++--- ocpmodels/common/utils.py | 15 ++++++- ocpmodels/models/faenet.py | 4 +- ocpmodels/modules/evaluator.py | 2 +- ocpmodels/trainers/base_trainer.py | 4 +- 8 files changed, 100 insertions(+), 20 deletions(-) diff --git a/configs/exps/deup/gnn/pretrain-depfaenet.yaml b/configs/exps/deup/gnn/pretrain-depfaenet.yaml index 83029997d4..27b7c57740 100644 --- a/configs/exps/deup/gnn/pretrain-depfaenet.yaml +++ b/configs/exps/deup/gnn/pretrain-depfaenet.yaml @@ -54,6 +54,33 @@ runs: - config: depfaenet-is2re-all note: depfaenet with top configs + dropout + model: + mp_type: updownscale_base + phys_embeds: True + tag_hidden_channels: 32 + pg_hidden_channels: 96 + energy_head: weighted-av-final-embeds + complex_mp: True + graph_norm: True + hidden_channels: 352 + num_filters: 288 + num_gaussians: 68 + num_interactions: 5 + second_layer_MLP: False + skip_co: False + cutoff: 4.0 + optim: + batch_size: 256 + eval_batch_size: 256 + lr_initial: 0.002 + scheduler: LinearWarmupCosineAnnealingLR + max_epochs: 9 + eval_every: 0.4 + + + + - config: depfaenet-is2re_oc22-all + note: finetune depfaenet on oc22 model: mp_type: updownscale_base phys_embeds: True diff --git a/configs/models/depfaenet.yaml b/configs/models/depfaenet.yaml index da19d6e350..d382a00442 100644 --- a/configs/models/depfaenet.yaml +++ 
b/configs/models/depfaenet.yaml @@ -98,6 +98,48 @@ is2re: warmup_steps: 6000 max_epochs: 20 +is2re_oc22: + # *** Important note *** + # The total number of gpus used for this run was 1. + # If the global batch size (num_gpus * batch_size) is modified + # the lr_milestones and warmup_steps need to be adjusted accordingly. + 10k: + optim: + lr_initial: 0.005 + lr_milestones: # epochs at which lr_initial <- lr_initial * lr_gamma + - 1562 + - 2343 + - 3125 + warmup_steps: 468 + max_epochs: 20 + + 100k: + model: + hidden_channels: 256 + optim: + lr_initial: 0.005 + lr_milestones: # epochs at which lr_initial <- lr_initial * lr_gamma + - 1562 + - 2343 + - 3125 + warmup_steps: 468 + max_epochs: 20 + + all: + model: + hidden_channels: 384 + num_interactions: 4 + optim: + batch_size: 256 + eval_batch_size: 256 + lr_initial: 0.001 + lr_gamma: 0.1 + lr_milestones: # steps at which lr_initial <- lr_initial * lr_gamma + - 18000 + - 27000 + - 37000 + warmup_steps: 6000 + max_epochs: 20 # ------------------ # ----- S2EF ----- # ------------------ diff --git a/configs/models/tasks/is2re.yaml b/configs/models/tasks/is2re.yaml index c7b3def5b1..59f08658c2 100644 --- a/configs/models/tasks/is2re.yaml +++ b/configs/models/tasks/is2re.yaml @@ -4,7 +4,7 @@ default: task: dataset: single_point_lmdb - description: "Relaxed state energy prediction from initial structure." + description: "Relaxed state energy prediction from initial structure OC22." type: regression metric: mae labels: @@ -22,18 +22,18 @@ default: dataset: default_val: val_id train: - src: /network/scratch/s/schmidtv/ocp/datasets/ocp/is2re/all/train/ + src: /network/projects/crystalgfn/catalyst/ocp/oc22/train/ normalize_labels: True - target_mean: -1.525913953781128 - target_std: 2.279365062713623 + target_mean: 0.38994525473291336 # -1.525913953781128 + target_std: 2.524972595834097 # 2.279365062713623 val_id: - src: /network/scratch/s/schmidtv/ocp/datasets/ocp/is2re/all/val_id/ + src: /network/projects/crystalgfn/catalyst/ocp/oc22/val_id/ val_ood_cat: - src: /network/scratch/s/schmidtv/ocp/datasets/ocp/is2re/all/val_ood_cat/ + src: /network/projects/crystalgfn/catalyst/ocp/oc22/val_ood/ val_ood_ads: - src: /network/scratch/s/schmidtv/ocp/datasets/ocp/is2re/all/val_ood_ads/ + src: /network/projects/crystalgfn/catalyst/ocp/oc22/val_ood/ val_ood_both: - src: /network/scratch/s/schmidtv/ocp/datasets/ocp/is2re/all/val_ood_both/ + src: /network/projects/crystalgfn/catalyst/ocp/oc22/val_ood/ # DEUP deup_dataset: create: False # "before" -> created before training (for deup) "after" -> created after training (for is2re) "" - not created diff --git a/ocpmodels/common/gfn.py b/ocpmodels/common/gfn.py index 9c7bec28c4..4e6a95a489 100644 --- a/ocpmodels/common/gfn.py +++ b/ocpmodels/common/gfn.py @@ -313,7 +313,7 @@ def to_data_list(batch): release = "0.0.1" # or ckpt_paths = { - "mila": "/network/scratch/a/alexandre.duval/ocp/catalyst-ckpts/0.0.1/best_checkpoint.pt", + "mila": "/network/projects/crystalgfn/catalyst/checkpoints/best_checkpoint_depfaenet.pt", "lpodina": "/home/felixt/shared/checkpoints/best_checkpoint.pt", "narval": "/home/felixt/shared/checkpoints/best_checkpoint.pt" } @@ -321,15 +321,15 @@ def to_data_list(batch): wrapper, loaders = prepare_for_gfn(ckpt_paths, release) wrapper.eval() - data_gen_ood_cat = iter(loaders["val_ood_cat"]) - data_gen_ood_both = iter(loaders["val_ood_both"]) - data_gen_ood_ads = iter(loaders["val_ood_ads"]) + # data_gen_ood_cat = iter(loaders["val_ood_cat"]) + # data_gen_ood_both = iter(loaders["val_ood_both"]) + # 
data_gen_ood_ads = iter(loaders["val_ood_ads"]) data_gen_id = iter(loaders["val_id"]) train_set_iterator = iter(loaders["train"]) print("Testing batches...") batch_i = 0 - while batch := next(data_gen_ood_both): + while batch := next(train_set_iterator): print(f"{batch_i=}") preds_1 = wrapper(deepcopy(batch)).detach().cpu().numpy() true_1 = np.array([b.y_relaxed for b in batch]).flatten() diff --git a/ocpmodels/common/utils.py b/ocpmodels/common/utils.py index b7a39d391b..76711c0641 100644 --- a/ocpmodels/common/utils.py +++ b/ocpmodels/common/utils.py @@ -167,7 +167,7 @@ def move_lmdb_data_to_slurm_tmpdir(trainer_config): print("\nšŸš‰ Copying data to slurm tmpdir", flush=True) - tmp_dir = os.environ.get("SLURM_TMPDIR") or f"/Tmp/slurm.{JOB_ID}.0" + tmp_dir = os.environ.get("SLURM_TMPDIR") or f"/tmp" tmp_dir = Path(tmp_dir) for s, split in trainer_config["dataset"].items(): if not isinstance(split, dict): @@ -1220,6 +1220,16 @@ def build_config(args, args_override=[], dict_overrides={}, silent=None): if isinstance(v, dict) and "src" in v } ) + target_mean_std = copy.deepcopy( + { + k: { + "target_mean": v["target_mean"], + "target_std": v["target_std"] + } # keep original src, if data was moved in the resumed exp + for k, v in config["dataset"].items() + if isinstance(v, dict) and "target_mean" in v + } + ) # override new config with loaded config config = merge_dicts(config, loaded_config) # set new dirs back @@ -1227,8 +1237,9 @@ def build_config(args, args_override=[], dict_overrides={}, silent=None): config, {k: resolve(v) if isinstance(v, (str, Path)) else v for k, v in new_dirs}, ) - # set new data sources back + # set new data sources and target mean/std back config["dataset"] = merge_dicts(config["dataset"], data_srcs) + config["dataset"] = merge_dicts(config["dataset"], target_mean_std) # parse overriding command-line args cli = cli_args_dict() # check max steps/epochs diff --git a/ocpmodels/models/faenet.py b/ocpmodels/models/faenet.py index e4737e1df8..8d31270bb8 100644 --- a/ocpmodels/models/faenet.py +++ b/ocpmodels/models/faenet.py @@ -157,8 +157,8 @@ def forward(self, z, rel_pos, edge_attr, tag=None, subnodes=None): # Concat period & group embedding if self.use_pg: - h_period = self.period_embedding(self.phys_emb.period[z]) - h_group = self.group_embedding(self.phys_emb.group[z]) + h_period = self.period_embedding(self.phys_emb.period[z] - 1) + h_group = self.group_embedding(self.phys_emb.group[z] - 1) h = torch.cat((h, h_period, h_group), dim=1) # MLP diff --git a/ocpmodels/modules/evaluator.py b/ocpmodels/modules/evaluator.py index 591a3cac20..359039bbe0 100644 --- a/ocpmodels/modules/evaluator.py +++ b/ocpmodels/modules/evaluator.py @@ -87,7 +87,7 @@ class Evaluator: } def __init__(self, task=None, model_regresses_forces=""): - assert task in ["s2ef", "is2rs", "is2re", "qm9", "qm7x", "deup_is2re"] + assert task in ["s2ef", "is2rs", "is2re", "qm9", "qm7x", "deup_is2re", "is2re_oc22"] self.task = task self.metric_fn = self.task_metrics[task] diff --git a/ocpmodels/trainers/base_trainer.py b/ocpmodels/trainers/base_trainer.py index a6710a1d12..bf1b891cca 100644 --- a/ocpmodels/trainers/base_trainer.py +++ b/ocpmodels/trainers/base_trainer.py @@ -551,8 +551,8 @@ def load_checkpoint(self, checkpoint_path, silent=False): self.ema = None for key in checkpoint["normalizers"]: - if key in self.normalizers: - self.normalizers[key].load_state_dict(checkpoint["normalizers"][key]) + # if key in self.normalizers: + # 
From 34fc831447aa0e22bf004b8aeae9f0798c4faf44 Mon Sep 17 00:00:00 2001
From: Lena Podina
Date: Tue, 15 Apr 2025 21:51:49 -0400
Subject: [PATCH 41/45] new configs and changes for those configs, train oc20 from scratch

---
 configs/models/depfaenet.yaml        | 43 ++++++++++++++++++++++
 configs/models/tasks/is2re_oc20.yaml | 53 ++++++++++++++++++++++++++++
 configs/models/tasks/is2re_oc22.yaml | 53 ++++++++++++++++++++++++++++
 ocpmodels/modules/evaluator.py       | 10 +++++-
 ocpmodels/trainers/single_trainer.py |  4 ++--
 5 files changed, 160 insertions(+), 3 deletions(-)
 create mode 100644 configs/models/tasks/is2re_oc20.yaml
 create mode 100644 configs/models/tasks/is2re_oc22.yaml

diff --git a/configs/models/depfaenet.yaml b/configs/models/depfaenet.yaml
index d382a00442..6cba51a7ad 100644
--- a/configs/models/depfaenet.yaml
+++ b/configs/models/depfaenet.yaml
@@ -140,6 +140,49 @@ is2re_oc22:
       - 37000
       warmup_steps: 6000
       max_epochs: 20
+
+is2re_oc20:
+  # *** Important note ***
+  # The total number of gpus used for this run was 1.
+  # If the global batch size (num_gpus * batch_size) is modified
+  # the lr_milestones and warmup_steps need to be adjusted accordingly.
+  10k:
+    optim:
+      lr_initial: 0.005
+      lr_milestones: # epochs at which lr_initial <- lr_initial * lr_gamma
+        - 1562
+        - 2343
+        - 3125
+      warmup_steps: 468
+      max_epochs: 20
+
+  100k:
+    model:
+      hidden_channels: 256
+    optim:
+      lr_initial: 0.005
+      lr_milestones: # epochs at which lr_initial <- lr_initial * lr_gamma
+        - 1562
+        - 2343
+        - 3125
+      warmup_steps: 468
+      max_epochs: 20
+
+  all:
+    model:
+      hidden_channels: 384
+      num_interactions: 4
+    optim:
+      batch_size: 256
+      eval_batch_size: 256
+      lr_initial: 0.001
+      lr_gamma: 0.1
+      lr_milestones: # steps at which lr_initial <- lr_initial * lr_gamma
+        - 18000
+        - 27000
+        - 37000
+      warmup_steps: 6000
+      max_epochs: 20
 # ------------------
 # ----- S2EF -----
 # ------------------
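The three optim blocks above encode one recipe at three dataset sizes: warm the learning rate up for warmup_steps, then multiply it by lr_gamma at each lr_milestones entry. Roughly the same schedule in plain PyTorch, as a sketch only (this is not the repo's scheduler code; everything is counted in optimizer steps, and SequentialLR hands over to MultiStepLR at the end of warmup, so the milestones are given relative to that handover):

import torch

params = [torch.nn.Parameter(torch.zeros(1))]
opt = torch.optim.AdamW(params, lr=0.001)  # lr_initial

warmup = torch.optim.lr_scheduler.LinearLR(opt, start_factor=0.01, total_iters=6000)
decay = torch.optim.lr_scheduler.MultiStepLR(
    opt, milestones=[12000, 21000, 31000], gamma=0.1  # 18000/27000/37000 minus warmup
)
sched = torch.optim.lr_scheduler.SequentialLR(opt, [warmup, decay], milestones=[6000])

for step in range(40000):
    opt.step()    # model update would happen here
    sched.step()  # one scheduler tick per optimizer step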
diff --git a/configs/models/tasks/is2re_oc20.yaml b/configs/models/tasks/is2re_oc20.yaml
new file mode 100644
index 0000000000..288f119ab5
--- /dev/null
+++ b/configs/models/tasks/is2re_oc20.yaml
@@ -0,0 +1,53 @@
+default:
+  trainer: single
+  logger: wandb
+
+  task:
+    dataset: single_point_lmdb
+    description: "Relaxed state energy prediction from initial structure OC20."
+    type: regression
+    metric: mae
+    labels:
+      - relaxed energy
+  optim:
+    optimizer: AdamW
+  normalizer: null
+  model:
+    otf_graph: False
+    max_num_neighbors: 40
+  mode: train
+  adsorbates: all # {"*O", "*OH", "*OH2", "*H"}
+  adsorbates_ref_dir: /network/scratch/s/schmidtv/ocp/datasets/ocp/per_ads
+  load_datasets: True
+  dataset:
+    default_val: val_id
+    train:
+      src: /network/projects/crystalgfn/catalyst/ocp/oc22/train/
+      normalize_labels: True
+      target_mean: -1.525913953781128
+      target_std: 2.279365062713623
+    val_id:
+      src: /network/projects/crystalgfn/catalyst/ocp/oc22/val_id/
+    val_ood_cat:
+      src: /network/projects/crystalgfn/catalyst/ocp/oc22/val_ood/
+    val_ood_ads:
+      src: /network/projects/crystalgfn/catalyst/ocp/oc22/val_ood/
+    val_ood_both:
+      src: /network/projects/crystalgfn/catalyst/ocp/oc22/val_ood/
+  # DEUP
+  deup_dataset:
+    create: False # "before" -> created before training (for deup) "after" -> created after training (for is2re) "" - not created
+    dataset_strs: ["train", "val_id", "val_ood_cat", "val_ood_ads"]
+    n_samples: 7
+
+10k:
+  dataset:
+    train:
+      src: /network/scratch/s/schmidtv/ocp/datasets/ocp/is2re/10k/train
+
+100k:
+  dataset:
+    train:
+      src: /network/scratch/s/schmidtv/ocp/datasets/ocp/is2re/100k/train
+
+all: {}
\ No newline at end of file
diff --git a/configs/models/tasks/is2re_oc22.yaml b/configs/models/tasks/is2re_oc22.yaml
new file mode 100644
index 0000000000..288f119ab5
--- /dev/null
+++ b/configs/models/tasks/is2re_oc22.yaml
@@ -0,0 +1,53 @@
+default:
+  trainer: single
+  logger: wandb
+
+  task:
+    dataset: single_point_lmdb
+    description: "Relaxed state energy prediction from initial structure OC22."
+    type: regression
+    metric: mae
+    labels:
+      - relaxed energy
+  optim:
+    optimizer: AdamW
+  normalizer: null
+  model:
+    otf_graph: False
+    max_num_neighbors: 40
+  mode: train
+  adsorbates: all # {"*O", "*OH", "*OH2", "*H"}
+  adsorbates_ref_dir: /network/scratch/s/schmidtv/ocp/datasets/ocp/per_ads
+  load_datasets: True
+  dataset:
+    default_val: val_id
+    train:
+      src: /network/projects/crystalgfn/catalyst/ocp/oc22/train/
+      normalize_labels: True
+      target_mean: -1.525913953781128
+      target_std: 2.279365062713623
+    val_id:
+      src: /network/projects/crystalgfn/catalyst/ocp/oc22/val_id/
+    val_ood_cat:
+      src: /network/projects/crystalgfn/catalyst/ocp/oc22/val_ood/
+    val_ood_ads:
+      src: /network/projects/crystalgfn/catalyst/ocp/oc22/val_ood/
+    val_ood_both:
+      src: /network/projects/crystalgfn/catalyst/ocp/oc22/val_ood/
+  # DEUP
+  deup_dataset:
+    create: False # "before" -> created before training (for deup) "after" -> created after training (for is2re) "" - not created
+    dataset_strs: ["train", "val_id", "val_ood_cat", "val_ood_ads"]
+    n_samples: 7
+
+10k:
+  dataset:
+    train:
+      src: /network/scratch/s/schmidtv/ocp/datasets/ocp/is2re/10k/train
+
+100k:
+  dataset:
+    train:
+      src: /network/scratch/s/schmidtv/ocp/datasets/ocp/is2re/100k/train
+
+all: {}
\ No newline at end of file
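Both new task files ship with the OC20 normalization constants (-1.525913953781128 / 2.279365062713623) even though their train.src entries point at OC22; the next patch fixes the OC20 paths but leaves the constants shared. Where the mismatch matters for a run, the helper sketched earlier can recompute them (the data.lmdb file name inside the split directory is an assumption):

mean, std = estimate_target_stats(
    "/network/projects/crystalgfn/catalyst/ocp/oc22/train/data.lmdb"
)
print(f"target_mean: {mean}")
print(f"target_std: {std}")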
"deup_is2re": ["energy"], @@ -81,13 +87,15 @@ class Evaluator: "s2ef": "energy_force_within_threshold", "is2rs": "average_distance_within_threshold", "is2re": "energy_mae", + "is2re_oc20": "energy_mae", + "is2re_oc22": "energy_mae", "qm9": "energy_mae", "qm7x": "energy_mae", "deup_is2re": "energy_mse", } def __init__(self, task=None, model_regresses_forces=""): - assert task in ["s2ef", "is2rs", "is2re", "qm9", "qm7x", "deup_is2re", "is2re_oc22"] + assert task in ["s2ef", "is2rs", "is2re", "qm9", "qm7x", "deup_is2re", "is2re_oc22", "is2re_oc20"] self.task = task self.metric_fn = self.task_metrics[task] diff --git a/ocpmodels/trainers/single_trainer.py b/ocpmodels/trainers/single_trainer.py index b9e2f921cf..c3099d0429 100644 --- a/ocpmodels/trainers/single_trainer.py +++ b/ocpmodels/trainers/single_trainer.py @@ -595,7 +595,7 @@ def compute_loss(self, preds, batch_list): [ ( batch.y_relaxed.to(self.device) - if self.task_name == "is2re" + if "is2re" in self.task_name else ( batch.deup_loss.to(self.device) if self.task_name == "deup_is2re" @@ -716,7 +716,7 @@ def compute_metrics( [ ( batch.y_relaxed.to(self.device) - if self.task_name == "is2re" + if "is2re" in self.task_name else ( batch.deup_loss.to(self.device) if self.task_name == "deup_is2re" From f6b896e6e284beecc10ec1e7300be128594c44b2 Mon Sep 17 00:00:00 2001 From: Lena Podina Date: Thu, 17 Apr 2025 00:14:31 -0400 Subject: [PATCH 42/45] is2re_oc20 yaml changed --- configs/models/tasks/is2re_oc20.yaml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/configs/models/tasks/is2re_oc20.yaml b/configs/models/tasks/is2re_oc20.yaml index 288f119ab5..1d9257bfd3 100644 --- a/configs/models/tasks/is2re_oc20.yaml +++ b/configs/models/tasks/is2re_oc20.yaml @@ -22,18 +22,18 @@ default: dataset: default_val: val_id train: - src: /network/projects/crystalgfn/catalyst/ocp/oc22/train/ + src: /network/scratch/s/schmidtv/ocp/datasets/ocp/is2re/all/train/ normalize_labels: True target_mean: -1.525913953781128 target_std: 2.279365062713623 val_id: - src: /network/projects/crystalgfn/catalyst/ocp/oc22/val_id/ + src: /network/scratch/s/schmidtv/ocp/datasets/ocp/is2re/all/val_id/ val_ood_cat: - src: /network/projects/crystalgfn/catalyst/ocp/oc22/val_ood/ + src: /network/scratch/s/schmidtv/ocp/datasets/ocp/is2re/all/val_ood_cat/ val_ood_ads: - src: /network/projects/crystalgfn/catalyst/ocp/oc22/val_ood/ + src: /network/scratch/s/schmidtv/ocp/datasets/ocp/is2re/all/val_ood_ads/ val_ood_both: - src: /network/projects/crystalgfn/catalyst/ocp/oc22/val_ood/ + src: /network/scratch/s/schmidtv/ocp/datasets/ocp/is2re/all/val_ood_both/ # DEUP deup_dataset: create: False # "before" -> created before training (for deup) "after" -> created after training (for is2re) "" - not created From 35a59ee9fe70445728321c1655092c9e33cea061 Mon Sep 17 00:00:00 2001 From: Lena Podina Date: Thu, 17 Apr 2025 00:21:32 -0400 Subject: [PATCH 43/45] merge gfn.py --- ocpmodels/common/gfn.py | 72 ++++++++++++++++++++++++++++++++++------- 1 file changed, 61 insertions(+), 11 deletions(-) diff --git a/ocpmodels/common/gfn.py b/ocpmodels/common/gfn.py index 4e6a95a489..6e709b7c74 100644 --- a/ocpmodels/common/gfn.py +++ b/ocpmodels/common/gfn.py @@ -171,11 +171,14 @@ def forward( if retrieve_hidden: return preds - # Denormalize predictions + # self.normalizers["target"].mean = 0.38994525473291336 + # self.normalizers["target"].std = 2.524972595834097 + preds["energy"] = self.normalizers["target"].denorm( preds["energy"], ) + return preds["energy"] def freeze(self): @@ 
From 35a59ee9fe70445728321c1655092c9e33cea061 Mon Sep 17 00:00:00 2001
From: Lena Podina
Date: Thu, 17 Apr 2025 00:21:32 -0400
Subject: [PATCH 43/45] merge gfn.py

---
 ocpmodels/common/gfn.py | 73 ++++++++++++++++++++++++++++++++++-------
 1 file changed, 62 insertions(+), 11 deletions(-)

diff --git a/ocpmodels/common/gfn.py b/ocpmodels/common/gfn.py
index 4e6a95a489..6e709b7c74 100644
--- a/ocpmodels/common/gfn.py
+++ b/ocpmodels/common/gfn.py
@@ -171,11 +171,14 @@ def forward(
         if retrieve_hidden:
             return preds
 
-        # Denormalize predictions
+        # self.normalizers["target"].mean = 0.38994525473291336
+        # self.normalizers["target"].std = 2.524972595834097
+
         preds["energy"] = self.normalizers["target"].denorm(
             preds["energy"],
         )
+
         return preds["energy"]
 
     def freeze(self):
@@ -244,8 +247,24 @@ def find_ckpt(ckpt_paths: dict, release: str) -> Path:
         )
     return ckpts[0]
-
-def prepare_for_gfn(ckpt_paths: dict, release: str) -> tuple:
+def get_test_dataset_configs(name):
+    if name == "OC22":
+        return {'default_val': 'val_id',
+                'train': {'src': '/network/projects/crystalgfn/catalyst/ocp/oc22/train/', 'normalize_labels': True,
+                          'target_mean': -1.525913953781128, 'target_std': 2.279365062713623, 'split': 'all'},
+                'val_id': {'src': '/network/projects/crystalgfn/catalyst/ocp/oc22/val_id/', 'split': 'all'},
+                'val_ood_cat': {'src': '/network/projects/crystalgfn/catalyst/ocp/oc22/val_ood/', 'split': 'all'},
+                'val_ood_ads': {'src': '/network/projects/crystalgfn/catalyst/ocp/oc22/val_ood/', 'split': 'all'},
+                'val_ood_both': {'src': '/network/projects/crystalgfn/catalyst/ocp/oc22/val_ood/', 'split': 'all'},
+                'val_ood': {'src': '/network/projects/crystalgfn/catalyst/ocp/oc22/val_ood/', 'split': 'all'}}
+    elif name == "OC20":
+        return {'default_val': 'val_id',
+                'train': {'src': '/network/scratch/s/schmidtv/ocp/datasets/ocp/is2re/all/train/', 'normalize_labels': True, 'target_mean': -1.525913953781128, 'target_std': 2.279365062713623, 'split': 'all'},
+                'val_id': {'src': '/network/scratch/s/schmidtv/ocp/datasets/ocp/is2re/all/val_id/', 'split': 'all'}, 'val_ood_cat': {'src': '/network/scratch/s/schmidtv/ocp/datasets/ocp/is2re/all/val_ood_cat/', 'split': 'all'},
+                'val_ood_ads': {'src': '/network/scratch/s/schmidtv/ocp/datasets/ocp/is2re/all/val_ood_ads/', 'split': 'all'}, 'val_ood_both': {'src': '/network/scratch/s/schmidtv/ocp/datasets/ocp/is2re/all/val_ood_both/', 'split': 'all'}}
+
+
+def prepare_for_gfn(ckpt_paths: dict, release: str, test_dataset_name=None) -> tuple:
     """
     Prepare a FAENet model for use in GFN.
     Loads the checkpoint for the given release on the current host, and wraps it
     in a FAENetWrapper.
@@ -269,14 +288,23 @@ def prepare_for_gfn(ckpt_paths: dict, release: str) -> tuple:
     """
     ckpt_path = find_ckpt(ckpt_paths, release)
     assert ckpt_path.exists(), f"Path {ckpt_path} does not exist."
+    if test_dataset_name:
+        overrides = {
+            "is_debug": True,
+            "silent": True,
+            "cp_data_to_tmpdir": False,
+            "dataset": get_test_dataset_configs(test_dataset_name)
+        }
+    else:
+        overrides = {
+            "is_debug": True,
+            "silent": True,
+            "cp_data_to_tmpdir": False
+        }
     trainer = make_trainer_from_dir(
         ckpt_path,
         mode="continue",
-        overrides={
-            "is_debug": True,
-            "silent": True,
-            "cp_data_to_tmpdir": False,
-        },
+        overrides=overrides,
         silent=True,
         skip_imports=["qm7x", "gemnet", "spherenet", "painn", "comenet"]
     )
@@ -313,12 +341,16 @@ def to_data_list(batch):
     release = "0.0.1"  # or
     ckpt_paths = {
-        "mila": "/network/projects/crystalgfn/catalyst/checkpoints/best_checkpoint_depfaenet.pt",
+        # "mila": "/network/projects/crystalgfn/catalyst/checkpoints/best_checkpoint_depfaenet.pt",
+        "mila": "/network/scratch/e/elena.podina/ocp/runs/6604136/checkpoints/best_checkpoint.pt",  # LP trained this one on OC20
+        # "mila": "/home/mila/e/elena.podina/scratch/ocp/runs/6549923/checkpoints/best_checkpoint.pt",
+        # "mila": "/network/projects/crystalgfn/catalyst/checkpoints/best_checkpoint_OER.pt",
         "lpodina": "/home/felixt/shared/checkpoints/best_checkpoint.pt",
         "narval": "/home/felixt/shared/checkpoints/best_checkpoint.pt"
     }
     release = None
-    wrapper, loaders = prepare_for_gfn(ckpt_paths, release)
+    test_dataset_name = "OC20"
+    wrapper, loaders = prepare_for_gfn(ckpt_paths, release, test_dataset_name=test_dataset_name)
     wrapper.eval()
 
     # data_gen_ood_cat = iter(loaders["val_ood_cat"])
     # data_gen_ood_both = iter(loaders["val_ood_both"])
@@ -329,12 +361,31 @@ def to_data_list(batch):
 
     print("Testing batches...")
     batch_i = 0
-    while batch := next(train_set_iterator):
+    import matplotlib.pyplot as plt  # required by the plotting below
+    y_relaxed = []
+    while batch := next(data_gen_id):
         print(f"{batch_i=}")
         preds_1 = wrapper(deepcopy(batch)).detach().cpu().numpy()
         true_1 = np.array([b.y_relaxed for b in batch]).flatten()
+        print(f"{preds_1=}")
+        print(f"{true_1=}")
+        print(f"{np.mean(true_1)=}")
+        print(f"{(preds_1 - true_1)=}")
         print(f"Test batch {batch_i} mae: {np.mean(np.abs(preds_1 - true_1))=}")
+        print(f"Test batch {batch_i} mae (only -5 to 5): {np.mean(np.abs(preds_1[(-5 < true_1) & (true_1 < 5)] - true_1[(-5 < true_1) & (true_1 < 5)]))=}")
         batch_i += 1
+        # if batch_i >= 5:
+        #     break
+        plt.plot(true_1.flatten(), preds_1.flatten(), 'o', label=f'batch = {batch_i}')
+        plt.grid()
+        plt.plot([-10, 10], [-10, 10])
+        plt.savefig(f'test_{test_dataset_name}_gfn')
+        plt.clf()
+        # print(f"Mean of batch {batch_i}: ", np.mean(true_1))
+        # print(f"Stdev of batch {batch_i}: ", np.std(true_1))
+        # y_relaxed += [b.y_relaxed for b in batch]
+    # print(f"Mean: ", np.mean(y_relaxed))
+    # print(f"Stdev: ", np.std(y_relaxed))
     exit(1)
 
     # print("Testing val id 10 batches...")

From 64d6fe07e19eafe16b081a8736e6819afd2c1f89 Mon Sep 17 00:00:00 2001
From: Lena Podina
Date: Sun, 20 Apr 2025 11:47:03 -0400
Subject: [PATCH 44/45] new config adds

---
 ocpmodels/modules/evaluator.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/ocpmodels/modules/evaluator.py b/ocpmodels/modules/evaluator.py
index 13d4573e31..d7cfb64e43 100644
--- a/ocpmodels/modules/evaluator.py
+++ b/ocpmodels/modules/evaluator.py
@@ -58,6 +58,11 @@ class Evaluator:
         "energy_mse",
         "energy_within_threshold",
     ],
+    "is2re_oc22": [
+        "energy_mae",
+        "energy_mse",
+        "energy_within_threshold",
+    ],
     "qm9": [
         "energy_mae",
         "energy_mse",
@@ -78,6 +83,7 @@ class Evaluator:
     "is2rs": ["positions", "cell", "pbc", "natoms"],
     "is2re": ["energy"],
     "is2re_oc20": ["energy"],
+    "is2re_oc22": ["energy"],
     "qm9": ["energy"],
     "qm7x": ["energy"],
     "deup_is2re": ["energy"],
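Patches 41 and 44 extend the Evaluator by hand in three dicts plus the __init__ assertion, and the next patch does it a third time for is2re_both. The pattern can be collapsed into one call site; a sketch with a hypothetical register_energy_task helper (not repo API):

from ocpmodels.modules.evaluator import Evaluator


def register_energy_task(name):
    """Register an energy-only regression task on Evaluator (hypothetical helper)."""
    Evaluator.task_metrics[name] = [
        "energy_mae",
        "energy_mse",
        "energy_within_threshold",
    ]
    Evaluator.task_attributes[name] = ["energy"]
    Evaluator.task_primary_metric[name] = "energy_mae"


for task in ("is2re_oc20", "is2re_oc22", "is2re_both"):
    register_energy_task(task)

The assertion in __init__ is a fourth copy of the task list; deriving it from the registered keys (assert task in self.task_metrics) would leave a single place to touch when a task is added.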
From fd95cd8732d3e65ef60cc2018b52b8128965d774 Mon Sep 17 00:00:00 2001
From: Lena Podina
Date: Wed, 23 Apr 2025 10:36:05 -0400
Subject: [PATCH 45/45] OC20+22 setup, removed slurm job kill

---
 configs/models/depfaenet.yaml        | 44 +++++++++++++++++++++
 configs/models/tasks/is2re_both.yaml | 53 ++++++++++++++++++++++++++++
 main.py                              |  6 ++--
 ocpmodels/datasets/lmdb_dataset.py   |  9 ++++-
 ocpmodels/modules/evaluator.py       |  9 ++++-
 5 files changed, 116 insertions(+), 5 deletions(-)
 create mode 100644 configs/models/tasks/is2re_both.yaml

diff --git a/configs/models/depfaenet.yaml b/configs/models/depfaenet.yaml
index 6cba51a7ad..039a7e452c 100644
--- a/configs/models/depfaenet.yaml
+++ b/configs/models/depfaenet.yaml
@@ -141,6 +141,50 @@ is2re_oc22:
       warmup_steps: 6000
       max_epochs: 20
+is2re_both:
+  # *** Important note ***
+  # The total number of gpus used for this run was 1.
+  # If the global batch size (num_gpus * batch_size) is modified
+  # the lr_milestones and warmup_steps need to be adjusted accordingly.
+  10k:
+    optim:
+      lr_initial: 0.005
+      lr_milestones: # epochs at which lr_initial <- lr_initial * lr_gamma
+        - 1562
+        - 2343
+        - 3125
+      warmup_steps: 468
+      max_epochs: 20
+
+  100k:
+    model:
+      hidden_channels: 256
+    optim:
+      lr_initial: 0.005
+      lr_milestones: # epochs at which lr_initial <- lr_initial * lr_gamma
+        - 1562
+        - 2343
+        - 3125
+      warmup_steps: 468
+      max_epochs: 20
+
+  all:
+    model:
+      hidden_channels: 384
+      num_interactions: 4
+    optim:
+      batch_size: 256
+      eval_batch_size: 256
+      lr_initial: 0.001
+      lr_gamma: 0.1
+      lr_milestones: # steps at which lr_initial <- lr_initial * lr_gamma
+        - 18000
+        - 27000
+        - 37000
+      warmup_steps: 6000
+      max_epochs: 20
+
+
 is2re_oc20:
   # *** Important note ***
   # The total number of gpus used for this run was 1.
diff --git a/configs/models/tasks/is2re_both.yaml b/configs/models/tasks/is2re_both.yaml
new file mode 100644
index 0000000000..d539e943bb
--- /dev/null
+++ b/configs/models/tasks/is2re_both.yaml
@@ -0,0 +1,53 @@
+default:
+  trainer: single
+  logger: wandb
+
+  task:
+    dataset: single_point_lmdb
+    description: "Relaxed state energy prediction from initial structure OC22 + OC20."
+    type: regression
+    metric: mae
+    labels:
+      - relaxed energy
+  optim:
+    optimizer: AdamW
+  normalizer: null
+  model:
+    otf_graph: False
+    max_num_neighbors: 40
+  mode: train
+  adsorbates: all # {"*O", "*OH", "*OH2", "*H"}
+  adsorbates_ref_dir: /network/scratch/s/schmidtv/ocp/datasets/ocp/per_ads
+  load_datasets: True
+  dataset:
+    default_val: val_id
+    train:
+      src: /network/projects/crystalgfn/catalyst/ocp/oc20_and_22/train/
+      normalize_labels: True
+      target_mean: -1.525913953781128
+      target_std: 2.279365062713623
+    val_id:
+      src: /network/projects/crystalgfn/catalyst/ocp/oc20_and_22/val_id/
+    val_ood_cat:
+      src: /network/projects/crystalgfn/catalyst/ocp/oc20_and_22/val_ood/
+    val_ood_ads:
+      src: /network/projects/crystalgfn/catalyst/ocp/oc20_and_22/val_ood/
+    val_ood_both:
+      src: /network/projects/crystalgfn/catalyst/ocp/oc20_and_22/val_ood/
+  # DEUP
+  deup_dataset:
+    create: False # "before" -> created before training (for deup) "after" -> created after training (for is2re) "" - not created
+    dataset_strs: ["train", "val_id", "val_ood_cat", "val_ood_ads"]
+    n_samples: 7
+
+10k:
+  dataset:
+    train:
+      src: /network/scratch/s/schmidtv/ocp/datasets/ocp/is2re/10k/train
+
+100k:
+  dataset:
+    train:
+      src: /network/scratch/s/schmidtv/ocp/datasets/ocp/is2re/100k/train
+
+all: {}
\ No newline at end of file
diff --git a/main.py b/main.py
index c28b88ef36..c3eae8c657 100644
--- a/main.py
+++ b/main.py
@@ -67,9 +67,9 @@ def wrap_up(args, start_time, error=None, signal=None, trainer=None):
         dist_utils.cleanup()
         print("Done!")
 
-        if "interactive" not in os.popen(f"squeue -hj {JOB_ID}").read():
-            print("\nSelf-canceling SLURM job in 32s", JOB_ID)
-            os.popen(f"sleep 32 && scancel {JOB_ID}")
+        # if "interactive" not in os.popen(f"squeue -hj {JOB_ID}").read():
+        #     print("\nSelf-canceling SLURM job in 32s", JOB_ID)
+        #     os.popen(f"sleep 32 && scancel {JOB_ID}")
 
     if trainer and trainer.logger:
         trainer.logger.finish(error or signal)
diff --git a/ocpmodels/datasets/lmdb_dataset.py b/ocpmodels/datasets/lmdb_dataset.py
index 8f8fb24442..997ae0727b 100644
--- a/ocpmodels/datasets/lmdb_dataset.py
+++ b/ocpmodels/datasets/lmdb_dataset.py
@@ -330,7 +330,14 @@ def __init__(self, config, transform=None):
 
 
 def data_list_collater(data_list, otf_graph=False):
-    batch = Batch.from_data_list(data_list)
+    for d in data_list:
+        if isinstance(d.natoms, torch.Tensor):
+            d.natoms = d.natoms.item()
+            d.sid = d.sid.item()
+    # try:
+    batch = Batch.from_data_list(data_list, exclude_keys=['force', 'y_init', 'nads', 'oc22', 'distance_vec'])
+    # except:
+    #     breakpoint()
 
     if (
         not otf_graph
diff --git a/ocpmodels/modules/evaluator.py b/ocpmodels/modules/evaluator.py
index d7cfb64e43..033f0e9c29 100644
--- a/ocpmodels/modules/evaluator.py
+++ b/ocpmodels/modules/evaluator.py
@@ -63,6 +63,11 @@ class Evaluator:
         "energy_mse",
         "energy_within_threshold",
     ],
+    "is2re_both": [
+        "energy_mae",
+        "energy_mse",
+        "energy_within_threshold",
+    ],
     "qm9": [
         "energy_mae",
         "energy_mse",
@@ -84,6 +89,7 @@ class Evaluator:
     "is2re": ["energy"],
     "is2re_oc20": ["energy"],
     "is2re_oc22": ["energy"],
+    "is2re_both": ["energy"],
     "qm9": ["energy"],
     "qm7x": ["energy"],
     "deup_is2re": ["energy"],
@@ -95,13 +101,14 @@ class Evaluator:
     "is2re": "energy_mae",
     "is2re_oc20": "energy_mae",
     "is2re_oc22": "energy_mae",
+    "is2re_both": "energy_mae",
     "qm9": "energy_mae",
     "qm7x": "energy_mae",
     "deup_is2re": "energy_mse",
     }
 
     def __init__(self, task=None, model_regresses_forces=""):
-        assert task in ["s2ef", "is2rs", "is2re", "qm9", "qm7x", "deup_is2re", "is2re_oc22", "is2re_oc20"]
+        assert task in ["s2ef", "is2rs", "is2re", "qm9", "qm7x", "deup_is2re", "is2re_oc22", "is2re_oc20", "is2re_both"]
         self.task = task
         self.metric_fn = self.task_metrics[task]
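The data_list_collater change above is what lets OC20 and OC22 samples share one batch: scalar metadata is coerced to plain ints and keys present in only one of the two datasets are dropped before stacking. The same idea in isolation (assuming torch_geometric 2.x, where Batch.from_data_list accepts exclude_keys; which dataset stores tensors versus ints is an assumption here):

import torch
from torch_geometric.data import Batch


def collate_mixed(data_list):
    # Harmonize scalar metadata: some splits store natoms/sid as 0-d tensors.
    for d in data_list:
        if isinstance(d.natoms, torch.Tensor):
            d.natoms = d.natoms.item()
        if isinstance(d.sid, torch.Tensor):
            d.sid = d.sid.item()
    # Drop keys that only one dataset defines so stacking cannot fail on them.
    return Batch.from_data_list(
        data_list,
        exclude_keys=["force", "y_init", "nads", "oc22", "distance_vec"],
    )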