3. Basic usage

3.1 Compute energy

DMFF uses OpenMM to parse input files, including the coordinate file, the topology specification file, and the force field parameter files. The core class Hamiltonian, which inherits from openmm.ForceField, is then initialized, and its createPotential method is called to create differentiable potential energy functions for the individual energy terms. Take parametrizing an organic molecule with the GAFF2 force field as an example:

import jax
import jax.numpy as jnp
import openmm.app as app
import openmm.unit as unit
from dmff import Hamiltonian, NeighborList

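# Register bond definitions so openmm.app can build the ligand topology
app.Topology.loadBondDefinitions("lig-top.xml")
pdb = app.PDBFile("lig.pdb")
# GAFF2 parameters plus the user-supplied partial charges
ff = Hamiltonian("gaff-2.11.xml", "lig-prm.xml")
# One differentiable potential function per energy term
potentials = ff.createPotential(pdb.topology)
for pot in potentials:
    print(pot)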

In this example, lig.pdb is the PDB file containing the atomic coordinates, and lig-top.xml specifies the bond connections within the molecule, which openmm.app needs to generate the molecular topology. This file is not required if the bond connections are already defined in the .pdb file with CONECT records. gaff-2.11.xml contains the GAFF2 force field parameters (bonds, angles, torsions and vdW), and lig-prm.xml contains the atomic partial charges (GAFF2 requires a user-defined charge assignment procedure). This XML format is compatible with the OpenMM definitions; a detailed description can be found in the OpenMM user guide or in the XML-format force fields section.
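For readers who take the CONECT route instead of lig-top.xml, the following sketch formats PDB CONECT records from a list of bonds. It is a standalone helper for illustration only, not part of DMFF or OpenMM; the function name conect_lines is hypothetical, and it assumes 1-based atom serial numbers that match the ATOM/HETATM records of the same file.

```python
from collections import defaultdict


def conect_lines(bonds):
    """Format PDB CONECT records from 1-based (serial_i, serial_j) bond pairs."""
    partners = defaultdict(list)
    for a, b in bonds:
        # List each bond from both ends, as PDB files typically do
        partners[a].append(b)
        partners[b].append(a)
    lines = []
    for atom in sorted(partners):
        bonded = sorted(partners[atom])
        # A single CONECT record holds at most four bonded serial numbers
        for start in range(0, len(bonded), 4):
            serials = [atom] + bonded[start:start + 4]
            lines.append("CONECT" + "".join(f"{s:5d}" for s in serials))
    return lines


# Methane-like ligand: carbon (serial 1) bonded to four hydrogens (serials 2-5)
for line in conect_lines([(1, 2), (1, 3), (1, 4), (1, 5)]):
    print(line)
```

Each record lists an atom's serial number followed by up to four bonded partners in 5-character columns; atoms with more partners continue on additional CONECT lines.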

Running the GAFF2 script above in examples/classical prints one potential function per energy term (the memory addresses will differ on your machine):

<function HarmonicBondJaxGenerator.createForce.<locals>.potential_fn at 0x112504af0>
<function HarmonicAngleJaxGenerator.createForce.<locals>.potential_fn at 0x1124cd820>
<function PeriodicTorsionJaxGenerator.createForce.<locals>.potential_fn at 0x18509b790>
<function NonbondJaxGenerator.createForce.<locals>.potential_fn at 0x18509baf0>

The force field parameters are stored as a Python dict in the params attribute of each force generator.

# Generator index 3 is the nonbonded term (same order as the potentials above)
nbparam = ff.getGenerators()[3].params
nbparam
{
    'sigma': DeviceArray([0.33152124, ...], dtype=float32),
    'epsilon': DeviceArray([0.4133792, ...], dtype=float32),
    'epsfix': DeviceArray([], dtype=float32),
    'sigfix': DeviceArray([], dtype=float32),
    'charge': DeviceArray([-0.75401515, ...], dtype=float32),
    'coulomb14scale': DeviceArray([0.8333333], dtype=float32),
    'lj14scale': DeviceArray([0.5], dtype=float32)
}
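To illustrate how per-atom sigma and epsilon values such as those in nbparam might enter a pairwise energy, here is a minimal NumPy sketch of a 12-6 Lennard-Jones pair term. It assumes Lorentz-Berthelot combining rules (the usual GAFF/AMBER convention) and ignores the epsfix/sigfix entries; the helper lj_pair_energy is hypothetical and is not DMFF's NonbondJaxGenerator implementation.

```python
import numpy as np


def lj_pair_energy(r, sigma_i, sigma_j, eps_i, eps_j, scale=1.0):
    """12-6 Lennard-Jones energy (kJ/mol) of one atom pair at distance r (nm).

    Assumes Lorentz-Berthelot combining rules. `scale` plays the role of
    lj14scale (0.5 in nbparam above) for 1-4 pairs and is 1.0 otherwise.
    """
    sigma = 0.5 * (sigma_i + sigma_j)   # arithmetic mean of the sigmas
    eps = np.sqrt(eps_i * eps_j)        # geometric mean of the epsilons
    sr6 = (sigma / r) ** 6
    return scale * 4.0 * eps * (sr6 ** 2 - sr6)


sigma, eps = 0.33152124, 0.4133792      # first sigma/epsilon entries printed above
r_min = 2.0 ** (1.0 / 6.0) * sigma      # distance of the LJ minimum
print(lj_pair_energy(r_min, sigma, sigma, eps, eps))  # equals -eps at the minimum
```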

Each generated function takes the atomic coordinates, the box, the neighbor pairs, and the force field parameters as inputs, which can be prepared as follows:

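# Coordinates in nm, converted to a JAX array
positions = jnp.array(pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer))
# Cubic periodic box with 10 nm edges (one box vector per row)
box = jnp.array([
    [10.0,  0.0,  0.0],
    [ 0.0, 10.0,  0.0],
    [ 0.0,  0.0, 10.0]
])
# Build the neighbor list with cutoff rc and extract the neighbor pairs
nbList = NeighborList(box, rc=4)
nbList.allocate(positions)
pairs = nbList.pairs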

Note that to take advantage of JAX's automatic differentiation, the input arrays must be jax.numpy.ndarray objects; otherwise DMFF will raise an error. pairs is an \(N\times 2\) integer array, where N is the number of neighbor pairs, and each row specifies two atoms considered as neighbors within the cutoff rc. As shown above, it can be computed with the dmff.NeighborList class, which is built on jax_md.
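As an illustration of what pairs contains, the sketch below builds the same kind of \(N\times 2\) array by brute force with NumPy, applying the minimum-image convention in a rectangular box. It scales quadratically with the number of atoms and is written for clarity, not as a replacement for dmff.NeighborList; the name brute_force_pairs is hypothetical, and the real neighbor list (built on jax_md) may order, pad, or store pairs differently.

```python
import numpy as np


def brute_force_pairs(positions, box, rc):
    """Return an (N, 2) int array of atom pairs (i < j) closer than rc.

    Assumes an orthorhombic box given as a 3x3 matrix with the box
    lengths on the diagonal.
    """
    positions = np.asarray(positions, dtype=float)
    lengths = np.diag(np.asarray(box, dtype=float))
    i, j = np.triu_indices(len(positions), k=1)   # every unique pair once
    d = positions[j] - positions[i]
    d -= lengths * np.round(d / lengths)          # minimum-image convention
    mask = np.linalg.norm(d, axis=1) < rc
    return np.stack([i[mask], j[mask]], axis=1)


toy_positions = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [9.5, 0.0, 0.0]])
toy_box = 10.0 * np.eye(3)
# Pairs (0, 1) and (0, 2): atom 2 is 0.5 nm from atom 0 across the boundary
print(brute_force_pairs(toy_positions, toy_box, rc=1.2))
```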

Each potential energy function returns the energy as a scalar in kJ/mol. For the nonbonded term:

nbfunc = potentials[3]
nbene = nbfunc(positions, box, pairs, nbparam)
print(nbene)

If everything works, this prints -425.41412. Alternatively, getPotentialFunc() and getParameters() return the total potential energy function and the full force field parameter set, instead of separate functions for the individual energy terms:

efunc = ff.getPotentialFunc()
params = ff.getParameters()
totene = efunc(positions, box, pairs, params)
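To make the calling convention concrete, the sketch below defines two toy terms with the same (positions, box, pairs, params) signature, each returning a scalar in kJ/mol, and sums them into one total-energy function. Everything here is hypothetical: the terms are not DMFF's generators, and params being a list with one entry per term is an assumption made only for this example; the object returned by getParameters() may be organized differently.

```python
import jax.numpy as jnp

ONE_4PI_EPS0 = 138.935456  # 1/(4*pi*eps0) in kJ mol^-1 nm e^-2


def coulomb_term(positions, box, pairs, params):
    """Toy bare Coulomb energy (kJ/mol) over neighbor pairs, minimum image."""
    q = params["charge"]
    i, j = pairs[:, 0], pairs[:, 1]
    d = positions[j] - positions[i]
    lengths = jnp.diag(box)                      # assumes an orthorhombic box
    d = d - lengths * jnp.round(d / lengths)     # minimum-image convention
    r = jnp.linalg.norm(d, axis=1)
    return jnp.sum(ONE_4PI_EPS0 * q[i] * q[j] / r)


def restraint_term(positions, box, pairs, params):
    """Toy harmonic restraint holding atom 0 near the origin (kJ/mol)."""
    return 0.5 * params["k"] * jnp.sum(positions[0] ** 2)


def make_total_energy(terms):
    """Hypothetical helper: sum the terms, pairing each with its own params entry."""
    def total_energy(positions, box, pairs, params):
        return sum(fn(positions, box, pairs, p) for fn, p in zip(terms, params))
    return total_energy


positions = jnp.array([[0.0, 0.0, 0.0], [0.1, 0.0, 0.0]])
box = 10.0 * jnp.eye(3)
pairs = jnp.array([[0, 1]])
params = [{"charge": jnp.array([1.0, -1.0])}, {"k": 100.0}]

efunc_toy = make_total_energy([coulomb_term, restraint_term])
print(efunc_toy(positions, box, pairs, params))  # ~ -1389.35: Coulomb only, restraint is 0
```

Because the sum is built from JAX operations, the combined function remains differentiable with jax.grad.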

3.2 Compute forces

Unlike conventional simulation frameworks, DMFF does not require explicitly coded force functions. Since the forces are the negative gradient of the energy with respect to the atomic positions, they can be evaluated automatically with jax.grad:

# argnums=0 differentiates with respect to the first argument, positions
pos_grad_func = jax.grad(efunc, argnums=0)
force = -pos_grad_func(positions, box, pairs, params)
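A minimal check of the autodiff approach on a toy system: the sketch below differentiates a single harmonic bond with jax.grad and compares the result with the hand-derived force. The energy function, force constant, and rest length are illustrative assumptions, not DMFF's HarmonicBondJaxGenerator.

```python
import jax
import jax.numpy as jnp

K_BOND, R0 = 1000.0, 0.1   # toy force constant (kJ/mol/nm^2) and rest length (nm)


def bond_energy(positions):
    """Toy harmonic bond between atoms 0 and 1: E = 0.5 * K * (r - R0)**2."""
    r = jnp.linalg.norm(positions[1] - positions[0])
    return 0.5 * K_BOND * (r - R0) ** 2


positions = jnp.array([[0.0, 0.0, 0.0], [0.2, 0.0, 0.0]])
# Force is the negative gradient; jax.grad returns an array shaped like positions
forces = -jax.grad(bond_energy)(positions)

# Hand-derived force on atom 1 for comparison: -K (r - R0) (x1 - x0) / r
r_vec = positions[1] - positions[0]
r = jnp.linalg.norm(r_vec)
f1_analytic = -K_BOND * (r - R0) * r_vec / r
print(forces)  # roughly [[100, 0, 0], [-100, 0, 0]]: the stretched bond pulls inward
```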

3.3 Compute parametric gradients

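Similarly, the derivatives of the energy with respect to the force field parameters are easy to compute. With argnums=-1, jax.grad differentiates with respect to the last argument, the parameter dict, and returns a dict with the same keys as nbparam: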

param_grad_func = jax.grad(nbfunc, argnums=-1)
pgrad = param_grad_func(positions, box, pairs, nbparam)
print(pgrad["charge"])
[ 652.7753      55.108738   729.36115   -171.4929     502.70837
  -44.917206   129.63994   -142.31796   -149.62088    453.21503
   46.372574   140.15303    575.488      461.46902    294.4358
  335.25153     27.828705   671.3637     390.8903     519.6835
  220.51129    238.7695     229.97302    210.58838    231.8734
  196.40994    237.08563     35.663574   457.76416     77.4798
  256.54382    402.2121     611.9573     440.8465     -52.09662
  421.86688    592.46265    237.98883    110.286194   150.65375
  218.61087    240.20477   -211.85376    150.7331     310.89404
  208.65228   -139.23026   -168.8883     114.3645       3.7261353
  399.6282     298.28455    422.06445    526.18463    521.27563
  575.85767    606.74744    394.40845    549.84033    556.4724
  485.1427     512.1267     558.55896    560.4667     562.812
  333.74194  ]
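The same mechanism works for any parameter dict. In the toy sketch below (a bare Coulomb sum, not DMFF's nonbonded term), jax.grad with argnums=-1 returns a dict keyed like params, and the derivative with respect to each charge can be checked by hand: \(\partial E/\partial q_i = k_e \sum_j q_j / r_{ij}\) over the pairs containing atom i, with \(k_e = 138.935456\) kJ mol⁻¹ nm e⁻². The gradient-descent step at the end, including its step size, is a hypothetical illustration of how such gradients could drive parameter fitting.

```python
import jax
import jax.numpy as jnp

ONE_4PI_EPS0 = 138.935456  # k_e = 1/(4*pi*eps0) in kJ mol^-1 nm e^-2


def coulomb_energy(positions, box, pairs, params):
    """Toy bare Coulomb energy (kJ/mol) over neighbor pairs; box is unused here."""
    q = params["charge"]
    i, j = pairs[:, 0], pairs[:, 1]
    r = jnp.linalg.norm(positions[j] - positions[i], axis=1)
    return jnp.sum(ONE_4PI_EPS0 * q[i] * q[j] / r)


positions = jnp.array([[0.0, 0.0, 0.0], [0.1, 0.0, 0.0], [0.0, 0.2, 0.0]])
box = 10.0 * jnp.eye(3)
pairs = jnp.array([[0, 1], [0, 2], [1, 2]])
params = {"charge": jnp.array([0.5, -0.3, -0.2])}

# argnums=-1 selects params; the gradient is a dict with the same keys and shapes
pgrad = jax.grad(coulomb_energy, argnums=-1)(positions, box, pairs, params)

# Hand check for atom 0: k_e * (q1 / r01 + q2 / r02) = k_e * (-3 - 1)
expected_q0 = -4.0 * ONE_4PI_EPS0

# Hypothetical gradient-descent step applied leaf by leaf (step size is arbitrary)
new_params = jax.tree_util.tree_map(lambda p, g: p - 1e-4 * g, params, pgrad)
```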