[WIP] Don't rely on configs in prepare_for_gfn #62


Open · wants to merge 13 commits into base: uncertainty
1 change: 1 addition & 0 deletions configs/models/tasks/is2re.yaml
@@ -1,6 +1,7 @@
default:
trainer: single
logger: wandb
prevent_load: {}

task:
dataset: single_point_lmdb
1 change: 1 addition & 0 deletions configs/models/tasks/qm7x.yaml
@@ -1,6 +1,7 @@
default:
trainer: single
logger: wandb
prevent_load: {}
eval_on_test: True

model:
1 change: 1 addition & 0 deletions configs/models/tasks/qm9.yaml
@@ -2,6 +2,7 @@ default:
trainer: single
logger: wandb
eval_on_test: True
prevent_load: {}

model:
otf_graph: False
2 changes: 2 additions & 0 deletions configs/models/tasks/s2ef.yaml
@@ -1,6 +1,8 @@
default:
trainer: single
logger: wandb
prevent_load: {}

task:
dataset: trajectory_lmdb
description: "Regressing to energies and forces for DFT trajectories from OCP"
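
The four configs above add the same `prevent_load: {}` default, so code that reads the key can assume it exists per task. Below is a minimal sketch of how a trainer might consult such flags; the `should_load` helper is hypothetical, not the actual OCP trainer API:

def should_load(component: str, config: dict) -> bool:
    """Return False when the config explicitly prevents loading `component`."""
    return not config.get("prevent_load", {}).get(component, False)

# With an empty `prevent_load`, everything loads; flags opt components out.
config = {"prevent_load": {"logger": True, "datasets": True}}
for component in ("logger", "loss", "datasets", "optimizer", "extras"):
    print(component, "->", "load" if should_load(component, config) else "skip")
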
70 changes: 53 additions & 17 deletions ocpmodels/common/gfn.py
@@ -1,16 +1,17 @@
import os
from copy import deepcopy
from pathlib import Path
from typing import Callable, Union, List

import os
from typing import Callable, List, Union

import torch
import torch.nn as nn
from torch_geometric.data.data import Data
from torch_geometric.data.batch import Batch
from torch_geometric.data.data import Data

from ocpmodels.common.utils import make_trainer_from_dir, resolve
from ocpmodels.models.faenet import FAENet
from ocpmodels.common.registry import registry
from ocpmodels.common.utils import resolve, setup_imports
from ocpmodels.datasets.data_transforms import get_transforms
from ocpmodels.models.faenet import FAENet


class FAENetWrapper(nn.Module):
@@ -190,6 +191,37 @@ def parse_loc() -> str:
return loc


def reset_data_paths(config):
"""
Reset config data paths to their defaults instead of SLURM temporary paths (modifies the config in place).

Args:
config (dict): The trainer config dictionary to modify.

Returns:
dict: The modified config dictionary.
"""
ds_configs = deepcopy(config["dataset"])
task_name = config["task"]["name"]
if task_name != "is2re":
raise NotImplementedError(
"Only the is2re task is currently supported for resetting data paths."
+ " To implement this for other tasks, modify how `base_path` is constructed"
" in `reset_data_paths()`"
)
base_path = Path("/network/projects/ocp/oc20/is2re")
for name, ds_config in ds_configs.items():
if not isinstance(ds_config, dict):
continue
if "slurm" in ds_config["src"].lower():
ds_config["src"] = str(
base_path / ds_config["split"] / Path(ds_config["src"]).name
)
config["dataset"][name] = ds_config

return config
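
A hypothetical before/after for `reset_data_paths` (the paths are illustrative):

# Input: a dataset `src` pointing at a SLURM tmpdir copy.
config = {
    "task": {"name": "is2re"},
    "dataset": {
        "train": {"src": "/Tmp/slurm.1234567.0/is2re/all/data.lmdb", "split": "all"}
    },
}
config = reset_data_paths(config)
# config["dataset"]["train"]["src"] is now
# "/network/projects/ocp/oc20/is2re/all/data.lmdb"
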


def find_ckpt(ckpt_paths: dict, release: str) -> Path:
"""
Finds a checkpoint in a dictionary of paths, based on the current cluster name and
@@ -223,7 +255,7 @@ def find_ckpt(ckpt_paths: dict, release: str) -> Path:
if path.is_file():
return path
path = path / release
ckpts = list(path.glob("**/*.ckpt"))
ckpts = list(path.glob("**/*.pt"))
if len(ckpts) == 0:
raise ValueError(f"No FAENet proxy checkpoint found at {str(path)}.")
if len(ckpts) > 1:
@@ -256,18 +288,22 @@ def prepare_for_gfn(ckpt_paths: dict, release: str) -> tuple:
Returns:
tuple: (model, loaders) where loaders is a dict of loaders for the model.
"""
setup_imports()
ckpt_path = find_ckpt(ckpt_paths, release)
assert ckpt_path.exists(), f"Path {ckpt_path} does not exist."
trainer = make_trainer_from_dir(
ckpt_path,
mode="continue",
overrides={
"is_debug": True,
"silent": True,
"cp_data_to_tmpdir": False,
},
silent=True,
)
config = torch.load(ckpt_path, map_location="cpu")["config"]
config["is_debug"] = True
config["silent"] = True
config["cp_data_to_tmpdir"] = False
config["prevent_load"] = {
"logger": True,
"loss": True,
"datasets": True,
"optimizer": True,
"extras": True,
}
config = reset_data_paths(config)
trainer = registry.get_trainer_class(config["trainer"])(**config)

wrapper = FAENetWrapper(
faenet=trainer.model,
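
Net effect of the gfn.py changes: rather than rebuilding a trainer from the run directory via `make_trainer_from_dir`, the trainer is constructed purely from the config stored inside the checkpoint. A condensed sketch of the new flow (the checkpoint path is illustrative):

ckpt = torch.load("/path/to/release/best_checkpoint.pt", map_location="cpu")
config = ckpt["config"]            # run config saved alongside the weights
config["is_debug"] = True          # avoid run-time side effects
config["prevent_load"] = {
    k: True for k in ("logger", "loss", "datasets", "optimizer", "extras")
}
config = reset_data_paths(config)  # swap SLURM tmp paths for cluster defaults
trainer = registry.get_trainer_class(config["trainer"])(**config)
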
23 changes: 16 additions & 7 deletions ocpmodels/common/utils.py
@@ -755,7 +755,16 @@ def add_edge_distance_to_graph(


# Copied from https://github.com/facebookresearch/mmf/blob/master/mmf/utils/env.py#L89.
def setup_imports(skip_imports=[]):
def setup_imports(skip_modules=[]):
"""Automatically load all of the modules, so that they register within the registry.

Parameters
----------
skip_modules : list, optional
List of module names (as ``str``) to skip while importing, by default [].
Use module names, not paths: for instance, to skip
``ocpmodels.models.gemnet_oc.gemnet_oc``, use ``skip_modules=["gemnet_oc"]``.
"""
from ocpmodels.common.registry import registry

try:
@@ -803,7 +812,7 @@ def setup_imports(skip_imports=[]):
splits = f.split(os.sep)
file_name = splits[-1]
module_name = file_name[: file_name.find(".py")]
if module_name not in skip_imports:
if module_name not in skip_modules:
importlib.import_module("ocpmodels.%s.%s" % (key[1:], module_name))

# manual model imports
@@ -1191,7 +1200,7 @@ def build_config(args, args_override=[], dict_overrides={}, silent=None):

# load config from `model-task-split` pattern
config = load_config(args.config)
# overwride with command-line args, including default values
# override with command-line args, including default values
config = merge_dicts(config, args_dict_with_defaults)
# override with build_config()'s overrides
config = merge_dicts(config, overrides)
@@ -1801,7 +1810,7 @@ def make_script_trainer(str_args=[], overrides={}, silent=False, mode="train"):
return trainer


def make_config_from_dir(path, mode, overrides={}, silent=None, skip_imports=[]):
def make_config_from_dir(path, mode, overrides={}, silent=None, skip_modules=[]):
"""
Make a config from a directory. This is useful when restarting or continuing from a
previous run.
@@ -1838,11 +1847,11 @@ def make_config_from_dir(path, mode, overrides={}, silent=None, skip_imports=[])
config = build_config(default_args, silent=silent)
config = merge_dicts(config, overrides)

setup_imports(skip_imports=skip_imports)
setup_imports(skip_modules=skip_modules)
return config


def make_trainer_from_dir(path, mode, overrides={}, silent=None, skip_imports=[]):
def make_trainer_from_dir(path, mode, overrides={}, silent=None, skip_modules=[]):
"""
Make a trainer from a directory.

@@ -1858,7 +1867,7 @@ def make_trainer_from_dir(path, mode, overrides={}, silent=None, skip_imports=[]
Returns:
Trainer: The loaded trainer.
"""
config = make_config_from_dir(path, mode, overrides, silent, skip_imports)
config = make_config_from_dir(path, mode, overrides, silent, skip_modules)
return registry.get_trainer_class(config["trainer"])(**config)


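
For reference, the renamed argument in use, mirroring the docstring's own example:

from ocpmodels.common.utils import setup_imports

# Register all models/trainers/datasets except the gemnet_oc module.
setup_imports(skip_modules=["gemnet_oc"])
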
13 changes: 12 additions & 1 deletion ocpmodels/datasets/qm7x.py
@@ -20,10 +20,20 @@
from torch_geometric.data import Data
from tqdm import tqdm

from cosmosis.dataset import CDataset
from ocpmodels.common.registry import registry
from ocpmodels.common.utils import ROOT

CDataset = object
try:
from cosmosis.dataset import CDataset
except ImportError:
print(
"Warning: `cosmosis` is not installed. `QM7X` will not be available.",
"See https://github.com/icanswim/cosmosis",
)
print(f"(message from {Path(__file__).resolve()})")


try:
import orjson as json # noqa: F401
except: # noqa: E722
@@ -33,6 +43,7 @@
"`orjson` is not installed. ",
"Consider `pip install orjson` to speed up json loading.",
)
print(f"(message from {Path(__file__).resolve()})")


class Molecule:
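
Because `CDataset` falls back to `object` when `cosmosis` is missing, callers can guard on that sentinel. A hypothetical check, not part of this diff:

# Hypothetical fail-fast guard before instantiating the QM7X dataset.
if CDataset is object:
    raise ImportError(
        "QM7X requires `cosmosis`; see https://github.com/icanswim/cosmosis"
    )
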
16 changes: 13 additions & 3 deletions ocpmodels/models/comenet.py
@@ -1,9 +1,19 @@
from dig.threedgraph.method import ComENet as DIGComENet
from ocpmodels.models.base_model import BaseModel
from copy import deepcopy

import torch

from ocpmodels.common.registry import registry
from ocpmodels.common.utils import conditional_grad
from copy import deepcopy
from ocpmodels.models.base_model import BaseModel

DIGComENet = None
try:
from dig.threedgraph.method import ComENet as DIGComENet
except ImportError:
from pathlib import Path

print("Warning: `dig` is not installed. `SphereNet` will not be available.")
print(f"(message from {Path(__file__).resolve()})\n")


@registry.register_model("comenet")
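
The same sentinel pattern applies here: `DIGComENet` stays `None` when `dig` is absent, so the model class could fail fast at construction. A hypothetical guard, not part of this diff:

# Hypothetical fail-fast guard inside ComENet.__init__.
if DIGComENet is None:
    raise ImportError("`dig` is required for ComENet but is not installed.")
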
16 changes: 13 additions & 3 deletions ocpmodels/models/spherenet.py
@@ -1,9 +1,19 @@
from dig.threedgraph.method import SphereNet as DIGSphereNet
from ocpmodels.models.base_model import BaseModel
from copy import deepcopy

import torch

from ocpmodels.common.registry import registry
from ocpmodels.common.utils import conditional_grad
from copy import deepcopy
from ocpmodels.models.base_model import BaseModel

DIGSphereNet = None
try:
from dig.threedgraph.method import SphereNet as DIGSphereNet
except ImportError:
from pathlib import Path

print("Warning: `dig` is not installed. `SphereNet` will not be available.")
print(f"(message from {Path(__file__).resolve()})\n")


@registry.register_model("spherenet")
5 changes: 4 additions & 1 deletion ocpmodels/modules/scheduler.py
@@ -1,10 +1,11 @@
"""scheduler.py
"""

import inspect

import torch.optim.lr_scheduler as lr_scheduler

from ocpmodels.common.utils import warmup_lr_lambda
import pytorch_warmup as warmup


class LRScheduler:
@@ -54,6 +55,8 @@ def scheduler_lambda_fn(x):
if not self.silent:
print(f"Using fidelity_max_steps for scheduler -> {T_max}")
if self.optim_config["warmup_steps"] > 0:
import pytorch_warmup as warmup

self.warmup_scheduler = warmup.ExponentialWarmup(
self.optimizer, warmup_period=self.optim_config["warmup_steps"]
)
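
Moving `import pytorch_warmup` inside the branch defers the dependency until warmup is actually requested (`warmup_steps > 0`). Per the pytorch-warmup docs, the warmup scheduler then dampens LR-scheduler steps in the training loop, roughly (variable names are illustrative):

with warmup_scheduler.dampening():
    lr_scheduler.step()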