FAENet++ #66

Open
wants to merge 126 commits into base: main
Commits (126)
6f81484
ewald message passing raw
Ramlaoui Dec 5, 2023
eaf80a4
notebook to compare activations between models
Ramlaoui Dec 14, 2023
3f444e7
quick addition of ewald to schnet
Ramlaoui Dec 14, 2023
07ba182
update activations notebook
Ramlaoui Feb 7, 2024
4b88715
added indications for future implementations
theosaulus Feb 7, 2024
23f2c6a
added dirty implementation of learned canonicalisation function
theosaulus Feb 7, 2024
3750d1c
removed comments
theosaulus Feb 7, 2024
38ae651
bug fix
theosaulus Feb 7, 2024
d8b80e3
modify configs for ewald
Ramlaoui Feb 16, 2024
e520dea
ewald model for schnet and faenet
Ramlaoui Feb 16, 2024
6ecd7d5
New tasks for S2EF relaxations
Ramlaoui Feb 16, 2024
6879391
update relaxations to current ocp updates
Ramlaoui Feb 17, 2024
56796e7
adapt for relaxations
Ramlaoui Feb 17, 2024
c758b0b
update logger name for ft
Ramlaoui Feb 17, 2024
1306850
update task
Ramlaoui Feb 17, 2024
7382cb5
add _unwrapped_model for relaxations
Ramlaoui Feb 23, 2024
5ce0486
fix results directory for the current config
Ramlaoui Feb 23, 2024
bb4d388
otf_graph in faenet class
Ramlaoui Feb 23, 2024
eceafc5
fix bug for relaxations metrics
Ramlaoui Feb 29, 2024
ef619f8
revoke debug options
Ramlaoui Feb 29, 2024
dda4c54
added orthonormalization
theosaulus Feb 29, 2024
4c38daa
modif sh
theosaulus Feb 29, 2024
9dcd071
use directories with debug, revert after
Ramlaoui Feb 29, 2024
0e2c35c
init separated UntrainedCanonicalisation class
theosaulus Feb 29, 2024
c6a09bf
currently not working - nodes with no edges
theosaulus Feb 29, 2024
da5ba0a
dive deeper in analysis
Ramlaoui Mar 20, 2024
e3589ef
change notebook structures + studies on mispredictions
Ramlaoui Mar 21, 2024
d33e514
added canonicalisation support - not running
theosaulus Mar 21, 2024
6e9dfc3
uniform canonicalisation calls
Ramlaoui Mar 21, 2024
7a16a32
fix bugs - now working
Ramlaoui Mar 21, 2024
4026dd3
add analysis notebooks
Ramlaoui Apr 4, 2024
cf974e0
better support of various equivariance methods - added vn point net
theosaulus Apr 4, 2024
b767541
reverted to original
theosaulus Apr 4, 2024
0eb02a7
comments cleaning
theosaulus Apr 4, 2024
cdb618d
add more systems taken into account for std plots
Ramlaoui Apr 5, 2024
57963ae
added support for learnable canonicalization functions
theosaulus Apr 5, 2024
e928ea5
add equiformer to models folder - not running (OOM)
Ramlaoui Apr 5, 2024
5dd6d78
equiformerv2 config
Ramlaoui Apr 5, 2024
7beb0ef
schedulers for equiformer_v2
Ramlaoui Apr 5, 2024
25cd6ef
draft for new vn deeper model
theosaulus Apr 11, 2024
b62034b
bug fix
theosaulus Apr 11, 2024
4acebcf
cleaning
theosaulus Apr 11, 2024
a26fcfb
minor changes
theosaulus Apr 11, 2024
3338f99
attempt to solve the oom problem - still occurring
theosaulus Apr 11, 2024
d2393bf
another attempt at solving the oom
theosaulus Apr 11, 2024
e933ca7
minor changes
theosaulus Apr 11, 2024
b02b4c1
modified model_forward in validate function to set it to inference mode
theosaulus Apr 12, 2024
36da25e
minor changes
theosaulus Apr 12, 2024
2c03afe
Merge branch 's2ef-relax' into equiformer
Ramlaoui Apr 12, 2024
ec35186
corrected cano_network initialization
theosaulus Apr 12, 2024
7feca9f
cleaning
theosaulus Apr 14, 2024
4f54c25
correcting cano_model optimization
theosaulus Apr 14, 2024
9d5060a
cleaning
theosaulus Apr 15, 2024
0a418c9
fixed invariance tests due to wrong device placement
theosaulus Apr 15, 2024
4b89353
initialization for dgcnn cano model
theosaulus Apr 15, 2024
350a4a6
temporary fix to avoid errors due to 2D cano
theosaulus Apr 15, 2024
e221810
fixed dgcnn model - working
theosaulus Apr 15, 2024
7aebabe
cleaning
theosaulus Apr 15, 2024
4ba9658
fix relaxation for faenet
Ramlaoui Apr 16, 2024
290c319
add flag for reloading config from run folder
Ramlaoui Apr 16, 2024
7368bd1
support reload_config flag
Ramlaoui Apr 16, 2024
5ae4a93
frame averaging to work on GPU (useful for relaxations)
Ramlaoui Apr 16, 2024
89b1e41
script to run relaxations
Ramlaoui Apr 16, 2024
c67bf99
scheduler to work with default equiformer config
Ramlaoui Apr 16, 2024
6f48800
fixed device when testing s2ef equivariances
theosaulus Apr 16, 2024
93d03d9
minor changes
theosaulus Apr 16, 2024
3d8aad8
fix default scheduler
Ramlaoui Apr 16, 2024
5e1f482
is2re for equiformer config
Ramlaoui Apr 16, 2024
67c2c10
large cleaning + added simple cano_method
theosaulus Apr 17, 2024
100a425
remove gpu-per-task because it crashes
Ramlaoui Apr 18, 2024
0e81854
sbatch file for new slurm updates
Ramlaoui Apr 23, 2024
7cd9647
update notebook for more comparisons
Ramlaoui May 10, 2024
e7abdab
Merge branch 'learn-cano-func' into equiformer
Ramlaoui May 10, 2024
c16b325
fix no equivariance module bug
Ramlaoui May 11, 2024
02dc54f
fix no equivariance module bug
Ramlaoui May 11, 2024
c9ffb9d
Merge branch 'learn-cano-func' into equiformer
Ramlaoui May 11, 2024
197dcf1
fix no equivariance module bug
Ramlaoui May 11, 2024
b58598e
Merge branch 'learn-cano-func' into equiformer
Ramlaoui May 11, 2024
5a80333
adjust model to work with faenet pipeline
Ramlaoui May 13, 2024
4d5f80c
use collater instead of from list to fix bugs
Ramlaoui May 13, 2024
76c4709
Merge branch 'learn-cano-func' into equiformer
Ramlaoui May 13, 2024
3529edb
collater usage fix
Ramlaoui May 15, 2024
ffaa7a3
Merge branch 'learn-cano-func' into equiformer
Ramlaoui May 15, 2024
a3af424
fix dataset when adsorbate is not specified
Ramlaoui May 15, 2024
e128db7
support for when cell is None
Ramlaoui May 15, 2024
9fca575
added support for Sign Equivariant SFA
theosaulus May 15, 2024
f7e02ef
fix support for qm7x
Ramlaoui May 15, 2024
2f24ee2
bug fix
theosaulus May 15, 2024
0909cff
debug
theosaulus May 16, 2024
8105827
centering pos before trained cano to have E(3) instead of O(3)
theosaulus May 19, 2024
ba951e6
testing signnet - not the right implem
theosaulus May 19, 2024
f011098
playing with VN pointnet
theosaulus May 19, 2024
24e38f9
minor changes
theosaulus May 19, 2024
774d86e
added support for QM7x and QM9
theosaulus May 19, 2024
38d3844
added support for sign_equiv_sfa, and renamed sign_inv_sfa appropriately
theosaulus May 19, 2024
54a3714
make relaxations work for new canonicalization
Ramlaoui May 20, 2024
9f03601
relaxation when model from cluster to cp_tmp
Ramlaoui May 20, 2024
eccdd7a
update plots order
Ramlaoui May 20, 2024
d9cf179
add sum of interaction layers
Ramlaoui May 20, 2024
cc7cc04
bug fix
theosaulus May 20, 2024
bb8c6f7
bug fix
theosaulus May 20, 2024
64887f6
fix fa when the positions are on gpu
Ramlaoui May 21, 2024
a3fd7dc
transforms for relaxation in the case of fa
Ramlaoui May 21, 2024
bc24537
update notebook
Ramlaoui May 21, 2024
88abff3
avoid splitting the batch when no trainable cano
Ramlaoui May 22, 2024
384f910
fix + gram schmidt
theosaulus May 28, 2024
48aec96
direct forces and support for qm7x sym
theosaulus May 28, 2024
5141584
minor changes
theosaulus May 28, 2024
1066e48
clean + added sign-inv E3-equiv net
theosaulus Jun 19, 2024
8b1a770
cleaning
theosaulus Jun 29, 2024
2cece77
added support for sign invariant E(3)-equivariant cano
theosaulus Jun 29, 2024
8eb0f73
add extra argument used on energy forward by the trainer
Ramlaoui Jun 30, 2024
c0d9200
allow to freeze backbone for transfer learning
Ramlaoui Jun 30, 2024
1f75cf1
linting
Ramlaoui Jun 30, 2024
8647ef2
revert debug case to default + lint
Ramlaoui Jun 30, 2024
2af5acd
Merge branch 'ewald-sum' into equiformer
Ramlaoui Jun 30, 2024
6bc378e
add ewald to dpp
Ramlaoui Jun 30, 2024
640166c
cleaning
theosaulus Jul 7, 2024
52eabe6
Merge branch 'equiformer' into learn-cano-func
Ramlaoui Jul 7, 2024
b750be1
Merge pull request #65 from RolnickLab/learn-cano-func
Ramlaoui Jul 7, 2024
4ead6db
support for ewald with dpp and qm9
Ramlaoui Jul 7, 2024
c40b5aa
fix memory leak for relaxations without transforms
Ramlaoui Jul 7, 2024
e9c39a8
add compatibility for old frame_averaging arguments in configs
Ramlaoui Jul 7, 2024
442be83
Merge pull request #57 from RolnickLab/equiformer
Ramlaoui Jul 7, 2024
52dc4cc
cleaning and comments
theosaulus Sep 3, 2024
5cfe47a
clean
theosaulus Sep 3, 2024
29 changes: 22 additions & 7 deletions configs/models/dpp.yaml
@@ -22,6 +22,21 @@ default:
phys_embeds: False # True
phys_hidden_channels: 0
energy_head: False # can be {False, weighted-av-initial-embeds, weighted-av-final-embeds}
######################################### Ewald message passing hyperparameters
use_atom_to_atom_mp: False
use_ewald: False
detach_ewald: True
ewald_hyperparams:
num_k_x: 2 # Number of k-points sampled along x (periodic systems)
num_k_y: 2 # Number of k-points sampled along y (periodic systems)
num_k_z: 5 # Number of k-points sampled along z (periodic systems)
downprojection_size: 8 # Size of linear bottleneck layer
num_hidden: 3 # Number of residuals in update function
# qm9
k_cutoff: 0.8 # Frequency cutoff [Å^-1]
delta_k: 0.2 # Voxel grid resolution [Å^-1]
num_k_rbf: 128 # Gaussian radial basis size (Fourier filter)
#########################################
optim:
batch_size: 4
eval_batch_size: 4
@@ -82,7 +97,7 @@ is2re:
s2ef:
default:
model:
regress_forces: "from_energy"
regress_forces: "direct"
force_decoder_type: "mlp" # can be {"" or "simple"} | only used if regress_forces is True
force_decoder_model_config:
simple:
@@ -127,19 +142,19 @@
# If the global batch size (num_gpus * batch_size) is modified
# the lr_milestones and warmup_steps need to be adjusted accordingly.
optim:
batch_size: 96
eval_batch_size: 96
batch_size: 48
eval_batch_size: 48
eval_every: 10000
num_workers: 8
lr_initial: 0.0001
lr_gamma: 0.1
lr_milestones: # steps at which lr_initial <- lr_initial * lr_gamma
- 20833
- 31250
- 41666
- 10433
- 15500
- 21633
warmup_steps: 10416
warmup_factor: 0.2
max_epochs: 3
max_epochs: 15
force_coefficient: 50
energy_coefficient: 1
energy_grad_coefficient: 5
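For orientation, the periodic Ewald hyperparameters above (num_k_x, num_k_y, num_k_z) bound an integer grid of reciprocal-lattice points, while k_cutoff, delta_k and num_k_rbf parameterise the aperiodic (qm9) variant. A minimal, hypothetical sketch of how such a k-point grid could be enumerated from a unit cell is shown below; the function name and signature are illustrative, not the repository's actual implementation.

import itertools
import torch

def build_k_grid(cell: torch.Tensor, num_k_x: int, num_k_y: int, num_k_z: int) -> torch.Tensor:
    # Illustrative sketch: enumerate reciprocal-lattice vectors k = nx*b1 + ny*b2 + nz*b3
    # with |n_i| <= num_k_i, skipping the zero-frequency term.
    # `cell` is the (3, 3) matrix whose rows are the lattice vectors.
    recip = 2 * torch.pi * torch.linalg.inv(cell).T  # rows are the reciprocal vectors b_i
    k_points = []
    for nx, ny, nz in itertools.product(
        range(-num_k_x, num_k_x + 1),
        range(-num_k_y, num_k_y + 1),
        range(-num_k_z, num_k_z + 1),
    ):
        if nx == ny == nz == 0:
            continue
        k_points.append(nx * recip[0] + ny * recip[1] + nz * recip[2])
    return torch.stack(k_points)  # (num_k, 3)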
86 changes: 86 additions & 0 deletions configs/models/equiformer_v2.yaml
@@ -0,0 +1,86 @@
# includes:
# - configs/s2ef/2M/base.yml

# trainer: equiformerv2_forces
default:
model:
name: equiformer_v2

use_pbc: True
regress_forces: True
otf_graph: True
max_neighbors: 20
max_radius: 12.0
max_num_elements: 90

num_layers: 12
sphere_channels: 128
attn_hidden_channels: 64 # [64, 96] This determines the hidden size of message passing. Do not necessarily use 96.
num_heads: 8
attn_alpha_channels: 64 # Not used when `use_s2_act_attn` is True.
attn_value_channels: 16
ffn_hidden_channels: 128
norm_type: 'layer_norm_sh' # ['rms_norm_sh', 'layer_norm', 'layer_norm_sh']

lmax_list: [6]
mmax_list: [2]
grid_resolution: 18 # [18, 16, 14, None] For `None`, simply comment this line.

num_sphere_samples: 128

edge_channels: 128
use_atom_edge_embedding: True
share_atom_edge_embedding: False # If `True`, `use_atom_edge_embedding` must be `True` and the atom edge embedding will be shared across all blocks.
distance_function: 'gaussian'
num_distance_basis: 512 # not used

attn_activation: 'silu'
use_s2_act_attn: False # [False, True] Switch between attention after S2 activation or the original EquiformerV1 attention.
use_attn_renorm: True # Attention re-normalization. Used for ablation study.
ffn_activation: 'silu' # ['silu', 'swiglu']
use_gate_act: False # [True, False] Switch between gate activation and S2 activation
use_grid_mlp: True # [False, True] If `True`, use projecting to grids and performing MLPs for FFNs.
use_sep_s2_act: True # Separable S2 activation. Used for ablation study.

alpha_drop: 0.1 # [0.0, 0.1]
drop_path_rate: 0.05 # [0.0, 0.05]
proj_drop: 0.0

weight_init: 'uniform' # ['uniform', 'normal']

optim:
batch_size: 1 # 6
eval_batch_size: 1 # 6
load_balancing: atoms
num_workers: 8
lr_initial: 0.0004 # [0.0002, 0.0004], eSCN uses 0.0008 for batch size 96

optimizer: AdamW
weight_decay: 0.001
scheduler: LambdaLR
lambda_type: cosine
warmup_factor: 0.2
warmup_epochs: 0.1
lr_min_factor: 0.01 #

max_epochs: 30
force_coefficient: 100
energy_coefficient: 2
clip_grad_norm: 100
ema_decay: 0.999
loss_energy: mae
loss_force: l2mae

eval_every: 5000

s2ef:
default: {}

2M: {}

is2re:
default:
model:
regress_forces: False

all: {}
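For reference, the LambdaLR / cosine scheduler settings above (warmup_factor, warmup_epochs, lr_min_factor) describe a warmup-then-cosine multiplier on lr_initial. A hedged sketch of such a multiplier is given below; the exact curve implemented by the trainer may differ.

import math

def cosine_lr_lambda(step, warmup_steps, total_steps, warmup_factor=0.2, lr_min_factor=0.01):
    # Sketch of a LambdaLR multiplier: linear warmup from warmup_factor to 1,
    # then cosine decay from 1 down to lr_min_factor.
    # Usable as: LambdaLR(optimizer, lr_lambda=lambda s: cosine_lr_lambda(s, ...))
    if step < warmup_steps:
        alpha = step / max(1, warmup_steps)
        return warmup_factor * (1.0 - alpha) + alpha
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return lr_min_factor + 0.5 * (1.0 - lr_min_factor) * (1.0 + math.cos(math.pi * progress))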
24 changes: 22 additions & 2 deletions configs/models/faenet.yaml
@@ -1,6 +1,13 @@
default:
cano_args:
equivariance_module: "" # "", fa, untrained_cano, trained_cano, sign_equiv_sfa, trained_sign_inv_sfa, untrained_sign_inv_sfa, trained_sign_inv_sfa_E3, untrained_sign_inv_sfa_E3
cano_type: "3D" # "2D", "3D", "DA", ""
cano_method: "" # "", pointnet, simple (= 0 hidden layer pointnet), dgcnn
# Frame averaging specific arguments
fa_method: "random" # {"", all, random, det, se3-all, se3-random, se3-det}
# Legacy FA arguments:
frame_averaging: "" # {"2D", "3D", "DA", ""}
fa_method: "" # {"", all, randon, det, se3-all, se3-randon, se3-det}
fa_method: "" # {"", all, random, det, se3-all, se3-random, se3-det}
model:
name: faenet
act: swish
@@ -41,6 +48,19 @@ default:
res_updown:
hidden_channels: 128
norm: batch1d # batch1d, layer or null
######################################### Ewald message passing hyperparameters
use_ewald: False
ewald_hyperparams:
num_k_x: 1 # Number of k-points sampled along x (periodic systems)
num_k_y: 1 # Number of k-points sampled along y (periodic systems)
num_k_z: 3 # Number of k-points sampled along z (periodic systems)
downprojection_size: 8 # Size of linear bottleneck layer
num_hidden: 3 # Number of residuals in update function
# params for qm9
k_cutoff: 0.4 # Frequency cutoff [Å^-1]
delta_k: 0.2 # Voxel grid resolution [Å^-1]
num_k_rbf: 48 # Gaussian radial basis size (Fourier filter)
#########################################
optim:
batch_size: 256
eval_batch_size: 256
@@ -69,7 +89,7 @@ is2re:
default:
graph_rewiring: remove-tag-0
frame_averaging: "2D" # {"2D", "3D", "DA", ""}
fa_method: "se3-random" # {"", all, randon, det, se3-all, se3-randon, se3-det}
fa_method: "se3-random" # {"", all, random, det, se3-all, se3-random, se3-det}
# *** Important note ***
# The total number of gpus used for this run was 1.
# If the global batch size (num_gpus * batch_size) is modified
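The new cano_args block selects between frame averaging and the learned or untrained canonicalisation modules added in this PR. As a rough sketch of the deterministic 3D case only (the flavour behind cano_type: "3D" with a deterministic fa_method), a frame can be taken from a PCA of the centred positions; the actual code additionally handles the sign ambiguities that frame averaging averages over, the SE(3) variants, and the trainable networks (pointnet, dgcnn), none of which are shown here.

import torch

def deterministic_3d_frame(pos: torch.Tensor) -> torch.Tensor:
    # Illustrative sketch: centre the positions, diagonalise their covariance,
    # and express the positions in the resulting eigenbasis.
    centered = pos - pos.mean(dim=0, keepdim=True)   # remove translations
    cov = centered.T @ centered / centered.shape[0]  # (3, 3) covariance
    _, eigvecs = torch.linalg.eigh(cov)              # columns are the frame axes
    return centered @ eigvecs                        # canonicalised coordinates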
36 changes: 35 additions & 1 deletion configs/models/painn.yaml
@@ -19,7 +19,41 @@ is2re:
s2ef:
default: {}
200k: {}
2M: {}
2M:
model:
name: painn
hidden_channels: 512
num_layers: 6
num_rbf: 128
cutoff: 12.0
max_neighbors: 50
# scale_file: configs/s2ef/all/painn/painn_nb6_scaling_factors.pt
regress_forces: True
direct_forces: True
use_pbc: True

optim:
batch_size: 32
eval_batch_size: 32
load_balancing: atoms
eval_every: 5000
num_workers: 2
optimizer: AdamW
optimizer_params:
amsgrad: True
weight_decay: 0. # 2e-6 (TF weight decay) / 1e-4 (lr) = 2e-2
lr_initial: 1.e-4
lr_gamma: 0.8
scheduler: ReduceLROnPlateau
mode: min
factor: 0.8
patience: 3
max_epochs: 80
force_coefficient: 100
energy_coefficient: 1
ema_decay: 0.999
clip_grad_norm: 10

20M: {}
all: {}

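For readability, the new 2M optim block corresponds roughly to the following PyTorch objects (a sketch under the assumption that the keys map one-to-one onto torch arguments; the trainer's actual wiring may differ).

import torch

def build_painn_optim(model, cfg):
    # Sketch: map the YAML optim block above onto AdamW + ReduceLROnPlateau.
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=cfg["lr_initial"],                        # 1.e-4
        weight_decay=cfg["weight_decay"],            # 0.
        amsgrad=cfg["optimizer_params"]["amsgrad"],  # True
    )
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode=cfg["mode"],          # "min"
        factor=cfg["factor"],      # 0.8
        patience=cfg["patience"],  # 3
    )
    return optimizer, scheduler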
13 changes: 13 additions & 0 deletions configs/models/schnet.yaml
@@ -16,6 +16,19 @@ default:
phys_embeds: False # True
phys_hidden_channels: 0
energy_head: False # can be {False, weighted-av-initial-embeds, weighted-av-final-embeds, random}
######################################### Ewald message passing hyperparameters
use_ewald: False
ewald_hyperparams:
num_k_x: 1 # Number of k-points sampled along x (periodic systems)
num_k_y: 1 # Number of k-points sampled along y (periodic systems)
num_k_z: 3 # Number of k-points sampled along z (periodic systems)
downprojection_size: 8 # Size of linear bottleneck layer
num_hidden: 3 # Number of residuals in update function
# params for qm9
k_cutoff: 0.4 # Frequency cutoff [Å^-1]
delta_k: 0.2 # Voxel grid resolution [Å^-1]
num_k_rbf: 48 # Gaussian radial basis size (Fourier filter)
#########################################
optim:
batch_size: 64
eval_batch_size: 64
13 changes: 12 additions & 1 deletion configs/models/tasks/s2ef.yaml
@@ -11,14 +11,25 @@ default:
grad_input: atomic forces
train_on_free_atoms: True
eval_on_free_atoms: True
relax_dataset:
# path to lmdb of systems to be relaxed (uses same lmdbs as is2re)
src: /network/scratch/s/schmidtv/ocp/datasets/ocp/is2re/all/val_id/
write_pos: True
relaxation_steps: 300
relax_opt:
maxstep: 0.04
memory: 50
damping: 1.0
alpha: 70.0
traj_dir: "trajectories" # specify directory you wish to log the entire relaxations, suppress otherwise
normalizer: null
mode: train
optim:
optimizer: AdamW
model:
otf_graph: False
max_num_neighbors: 40
regress_forces: from_energy # can be in{ "from_energy", "direct", "direct_with_gradient_target" }
regress_forces: direct_with_gradient_target # can be in{ "from_energy", "direct", "direct_with_gradient_target" }
dataset:
default_val: val_id
train:
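The switch from from_energy to direct_with_gradient_target changes how forces are obtained. As a hedged illustration (the model call below is a hypothetical stand-in for the real forward signature): in from_energy mode forces come from differentiating the predicted energy with respect to positions, while direct modes predict them with a separate head, and direct_with_gradient_target additionally supervises that head with gradient targets.

import torch

def forces_from_energy(model, pos, batch):
    # Sketch of the "from_energy" mode: forces are the negative gradient of the
    # predicted energy with respect to atomic positions, obtained via autograd.
    pos = pos.detach().requires_grad_(True)
    energy = model(pos, batch)  # per-system energy, shape (num_systems,)
    forces = -torch.autograd.grad(energy.sum(), pos, create_graph=True)[0]
    return forces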
10 changes: 3 additions & 7 deletions mila/sbatch.py
@@ -41,7 +41,7 @@
conda activate {env}
fi
{wandb_offline}
srun --gpus-per-task=1 --output={output} {python_command}
srun --output={output} {python_command}
"""


@@ -234,9 +234,7 @@ def load_sbatch_args_from_dir(dir):
k, v = (
line[2:]
if line.startswith("--")
else line[1:]
if line.startswith("-")
else line
else line[1:] if line.startswith("-") else line
).split("=")
sbatch_args[k] = v
args = {
@@ -280,9 +278,7 @@ def load_sbatch_args_from_dir(dir):
modules = (
[]
if not args.modules
else args.modules.split(",")
if isinstance(args.modules, str)
else args.modules
else args.modules.split(",") if isinstance(args.modules, str) else args.modules
)
if args.verbose:
args.pretty_print()
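Both conditional-expression hunks above are formatter-style re-layouts of the same logic; no behaviour changes. The first one's logic, stripping a leading "--" or "-" from a recorded SBATCH option and then splitting it on "=", is equivalent to this small flattened sketch (the helper name is illustrative):

def parse_sbatch_option(line: str):
    # Flattened form of the nested conditional expression in load_sbatch_args_from_dir:
    # drop a leading "--" or "-", then split the option into key and value.
    if line.startswith("--"):
        line = line[2:]
    elif line.startswith("-"):
        line = line[1:]
    key, value = line.split("=")
    return key, value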