Optimization

1. Theory

Automatic differentiation is a crucial component of DMFF and plays a significant role in optimizing neural networks. This technique computes the derivatives of outputs with respect to inputs using backpropagation, so parameter optimization can be carried out with gradient descent algorithms. Because it is efficient at optimizing high-dimensional parameters, the technique is not limited to training neural networks; it is also suitable for optimizing any physical model (i.e., a molecular force field in the case of DMFF). A typical optimization recipe needs two key ingredients: 1. gradient evaluation, which can be done easily using JAX; and 2. an optimizer that takes gradients as input and updates the parameters following a given optimization algorithm. To help users build optimization workflows, DMFF provides a wrapper API for the optimizers implemented in Optax, which is introduced here.
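The two ingredients can be illustrated with a minimal sketch in plain JAX. The toy harmonic "model", the synthetic reference data, and the names `loss`, `k`, and `x0` are assumptions made for illustration only; they are not part of DMFF.

```python
import jax
import jax.numpy as jnp

# Toy model (assumption): a harmonic energy k * (x - x0)^2 with parameters k and x0
def loss(params, x, target):
    pred = params["k"] * (x - params["x0"]) ** 2
    return jnp.mean((pred - target) ** 2)

x = jnp.linspace(-1.0, 1.0, 21)
target = 2.0 * (x - 0.1) ** 2  # synthetic reference data

params = {"k": jnp.array(1.0), "x0": jnp.array(0.0)}

# Ingredient 1: gradient evaluation via automatic differentiation
grad_fn = jax.jit(jax.value_and_grad(loss))

initial_loss = float(loss(params, x, target))

# Ingredient 2: an optimizer, here hand-written plain gradient descent
lr = 0.05
for _ in range(200):
    value, grads = grad_fn(params, x, target)
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

final_loss = float(loss(params, x, target))
```

In practice the hand-written update rule would be replaced by an Optax optimizer, which is what the wrapper API described here is for.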

2. Function module

Function periodic_move:

  • Creates a function to perform a periodic update on parameters. If the update causes the parameters to exceed a given range, they are wrapped around in a periodic manner (like an angle that wraps around after 360 degrees).
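The wrapping behaviour can be sketched as follows. `wrap_periodic` is a hypothetical helper written for this illustration; it is not the signature of DMFF's `periodic_move`, and it only shows the idea of folding an updated value back into a range.

```python
import jax.numpy as jnp

def wrap_periodic(params, updates, pmin, pmax):
    """Apply updates, then fold the result back into [pmin, pmax)."""
    period = pmax - pmin
    return pmin + jnp.mod(params + updates - pmin, period)

angles = jnp.array([350.0, 10.0])   # degrees
steps = jnp.array([20.0, -30.0])    # proposed updates

wrapped = wrap_periodic(angles, steps, 0.0, 360.0)
# 350 + 20 = 370 wraps to 10; 10 - 30 = -20 wraps to 340
```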

Function genOptimizer:

  • Generates an optimizer based on user preferences.
  • Depending on the arguments, it can produce various optimization schemes, such as SGD, Nesterov, Adam, etc.
  • Supports learning rate schedules like exponential decay and warmup exponential decay.
  • The optimizer can be further augmented with features like gradient clipping, periodic parameter wrapping, and keeping parameters non-negative.

Functions label_iter, mark_iter, and label2trans_iter:

  • These are utility functions used for tree-like structured data (common with JAX's pytree concept).
  • label_iter recursively labels the tree nodes.
  • mark_iter marks each node in the tree with False.
  • label2trans_iter checks and updates the mask tree based on whether the label exists in the transformations. If not, it sets that transformation to zero.
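The idea behind these helpers can be sketched on a nested dictionary in pure Python. The functions `label_tree`, `mark_tree`, and `fill_missing` below are hypothetical stand-ins written for illustration; their names, signatures, and the path-style labels are assumptions and do not reproduce DMFF's implementation.

```python
def label_tree(tree, prefix=""):
    """Label every leaf with its path, e.g. 'Force1/Parameter1'."""
    if isinstance(tree, dict):
        return {
            key: label_tree(value, f"{prefix}/{key}" if prefix else key)
            for key, value in tree.items()
        }
    return prefix

def mark_tree(tree, value=False):
    """Build a tree of the same shape whose leaves are all `value`."""
    if isinstance(tree, dict):
        return {key: mark_tree(sub, value) for key, sub in tree.items()}
    return value

def fill_missing(labels, transforms, default="zero"):
    """Give every label without a transformation the `default` entry."""
    if isinstance(labels, dict):
        for sub in labels.values():
            fill_missing(sub, transforms, default)
    else:
        transforms.setdefault(labels, default)
    return transforms

# Hypothetical parameter tree; lists play the role of array leaves
params = {
    "Force1": {"Parameter1": [1.0], "Parameter2": [2.0]},
    "Force2": {"Parameter1": [3.0]},
}

labels = label_tree(params)
mask = mark_tree(params)
transforms = fill_missing(labels, {"Force1/Parameter1": "adam"})
```

Here only `Force1/Parameter1` was given a transformation by the user, so the other two labels fall back to the "zero" placeholder, which mirrors the behaviour described for missing transformations.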

Class MultiTransform:

  • Manages multiple transformations simultaneously on tree-structured data.
  • Maintains a record of transformations, labels, and masks.
  • The finalize method ensures that every label has a corresponding transformation, setting any missing transformations to zero.

3. How to use it

To set up an optimization, you should follow these steps:

  • Initialize MultiTransform with Parameter Tree:
multiTrans = MultiTransform(your_parameter_tree)
  • Define Optimizers for Desired Labels

  • For each part of the parameter tree you want to optimize differently, set an optimizer. For example:

multiTrans["Force1"]["Parameter1"] = genOptimizer(learning_rate=lr1, clip=clip1)
multiTrans["Force1"]["Parameter2"] = genOptimizer(learning_rate=lr2, clip=clip2)
  • Finalize MultiTransform

  • Before using it, always finalize the MultiTransform:

multiTrans.finalize()
  • Create a Combined Gradient Transformation:
grad_transform = optax.multi_transform(multiTrans.transforms, multiTrans.labels)
  • Mask Parameters (If Needed)

If you have parameters in your tree that shouldn't be updated, create a mask and then mask your transformation:

grad_transform = optax.masked(grad_transform, hamiltonian.getParameters().mask)
  • Initialize Optimization State:
opt_state = grad_transform.init(hamiltonian.getParameters().parameters)

By following the above steps, you'll have a grad_transform that can handle complex parameter trees and an optimization state opt_state ready for your optimization routine.