3. Basic usage
3.1 Compute energy
DMFF uses OpenMM to parse input files, including the coordinates file, the topology specification file and the force field parameter file. Then the core class Hamiltonian, which inherits from openmm.app.ForceField, is initialized, and its createPotential method is called to create differentiable potential energy functions for the different energy terms. Take parametrizing an organic molecule with the GAFF2 force field as an example:
import jax
import jax.numpy as jnp
import openmm.app as app
import openmm.unit as unit
from dmff import Hamiltonian, NeighborList
app.Topology.loadBondDefinitions("lig-top.xml")
pdb = app.PDBFile("lig.pdb")
ff = Hamiltonian("gaff-2.11.xml", "lig-prm.xml")
potentials = ff.createPotential(pdb.topology)
for pot in potentials:
    print(pot)
In this example, lig.pdb is the PDB file containing the atomic coordinates, and lig-top.xml specifies the bond connections within the molecule; openmm.app requires this information to generate the molecular topology. Note that this file is not required if the bond connections are already defined in the .pdb file with CONECT records. gaff-2.11.xml contains the GAFF2 force field parameters (bonds, angles, torsions and vdW), and lig-prm.xml contains the atomic partial charges (GAFF2 requires a user-defined charge assignment process). This XML format is compatible with OpenMM's definitions; a detailed description can be found in the OpenMM user guide or in the XML-format force fields section.
If you run this script in examples/classical, you will get the following output.
<function HarmonicBondJaxGenerator.createForce.<locals>.potential_fn at 0x112504af0>
<function HarmonicAngleJaxGenerator.createForce.<locals>.potential_fn at 0x1124cd820>
<function PeriodicTorsionJaxGenerator.createForce.<locals>.potential_fn at 0x18509b790>
<function NonbondJaxGenerator.createForce.<locals>.potential_fn at 0x18509baf0>
The force field parameters are stored as a Python dict in the params attribute of each force generator.
nbparam = ff.getGenerators()[3].params
nbparam
{
    'sigma': DeviceArray([0.33152124, ...], dtype=float32),
    'epsilon': DeviceArray([0.4133792, ...], dtype=float32),
    'epsfix': DeviceArray([], dtype=float32),
    'sigfix': DeviceArray([], dtype=float32),
    'charge': DeviceArray([-0.75401515, ...], dtype=float32),
    'coulomb14scale': DeviceArray([0.8333333], dtype=float32),
    'lj14scale': DeviceArray([0.5], dtype=float32)
}
Each generated function takes the coordinates, box, pairs and force field parameters as inputs.
positions = jnp.array(pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer))
box = jnp.array([
    [10.0,  0.0,  0.0],
    [ 0.0, 10.0,  0.0],
    [ 0.0,  0.0, 10.0]
])
nbList = NeighborList(box, rc=4)
nbList.allocate(positions)
pairs = nbList.pairs
Note that in order to take advantage of JAX's automatic differentiation, the input arrays have to be jax.numpy.ndarray; otherwise DMFF will raise an error. pairs is an \(N\times 2\) integer array in which each row specifies a pair of atoms considered neighbors within the cutoff distance (rc in the code above). As shown above, it can be computed with the dmff.NeighborList class, which is backed by jax_md.
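To make the layout of pairs concrete, here is a minimal sketch of a brute-force neighbor search. It is not how dmff.NeighborList or jax_md build their lists. The helper brute_force_pairs and the toy_* arrays are hypothetical names introduced only for illustration, and an orthorhombic box is assumed.

```python
import jax.numpy as jnp
import numpy as np


def brute_force_pairs(positions, box, rcut):
    """Return an (M, 2) integer array of atom pairs (i < j) closer than rcut.

    Illustration only: assumes an orthorhombic box (diagonal 3x3 matrix),
    applies the minimum-image convention, and scales as O(N^2).
    """
    lengths = jnp.diag(box)
    # Displacement vectors between all atom pairs, shape (N, N, 3)
    dr = positions[:, None, :] - positions[None, :, :]
    # Wrap every component into [-L/2, L/2] (minimum image)
    dr = dr - lengths * jnp.round(dr / lengths)
    dist = jnp.linalg.norm(dr, axis=-1)
    # Keep each unordered pair once (upper triangle, i < j)
    i_idx, j_idx = np.triu_indices(positions.shape[0], k=1)
    mask = np.asarray(dist[i_idx, j_idx] < rcut)
    return jnp.stack([i_idx[mask], j_idx[mask]], axis=1)


# Toy system (hypothetical values, in nm): atom 2 is close to atoms 0 and 1
# only through the periodic boundary; atom 3 is far from everything.
toy_positions = jnp.array([
    [0.5, 0.5, 0.5],
    [0.8, 0.5, 0.5],
    [9.9, 0.5, 0.5],
    [5.0, 5.0, 5.0],
])
toy_box = jnp.eye(3) * 10.0

toy_pairs = brute_force_pairs(toy_positions, toy_box, rcut=1.0)
print(toy_pairs)  # rows: (0, 1), (0, 2), (1, 2)
```

The third atom sits across the periodic boundary from the first two, yet it still appears in the list because distances are measured under the minimum-image convention. A production neighbor list avoids the quadratic search and may pad its output, so treat this only as a picture of what each row of pairs means.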
The potential energy function will give energy (a scalar, in kJ/mol) as output:
nbfunc = potentials[3]
nbene = nbfunc(positions, box, pairs, nbparam)
print(nbene)
If everything works, you will get -425.41412 as the result. You can also use getPotentialFunc() and getParameters() to obtain the whole potential energy function and the full force field parameter set, instead of separate functions for the different energy terms.
efunc = ff.getPotentialFunc()
params = ff.getParameters()
totene = efunc(positions, box, pairs, params)
3.2 Compute forces
Unlike in conventional programming frameworks, atomic force calculation functions no longer need to be defined explicitly. Instead, the forces can be evaluated automatically with jax.grad, as the negative gradient of the energy with respect to the positions.
pos_grad_func = jax.grad(efunc, argnums=0)
force = -pos_grad_func(positions, box, pairs, params)
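The same pattern can be checked by hand on a system small enough to solve analytically. The sketch below uses a single harmonic bond; toy_bond_energy, the force constant and the geometry are hypothetical and are not taken from DMFF or GAFF2.

```python
import jax
import jax.numpy as jnp


def toy_bond_energy(positions, k, r0):
    """Harmonic bond between atoms 0 and 1: E = 0.5 * k * (r - r0)**2."""
    r = jnp.linalg.norm(positions[1] - positions[0])
    return 0.5 * k * (r - r0) ** 2


# Two atoms on the x axis, stretched 0.02 beyond the equilibrium length
toy_positions = jnp.array([
    [0.00, 0.0, 0.0],
    [0.12, 0.0, 0.0],
])
k, r0 = 1000.0, 0.10

# Force = negative gradient of the energy w.r.t. positions (argument 0)
toy_force = -jax.grad(toy_bond_energy, argnums=0)(toy_positions, k, r0)

# Analytic result: |F| = k * (r - r0) = 20, pulling the atoms together
expected = jnp.array([
    [20.0, 0.0, 0.0],
    [-20.0, 0.0, 0.0],
])
print(jnp.allclose(toy_force, expected, atol=1e-3))
```

The returned array has the same shape as the positions, one force vector per atom, which is also what the DMFF call above produces for the full potential.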
3.3 Compute parametric gradients
Similarly, the derivative of the energy with respect to the force field parameters can also be computed easily. Here argnums=-1 selects the last argument, the parameter dict.
param_grad_func = jax.grad(nbfunc, argnums=-1)
pgrad = param_grad_func(positions, box, pairs, nbparam)
print(pgrad["charge"])
[ 652.7753 55.108738 729.36115 -171.4929 502.70837
-44.917206 129.63994 -142.31796 -149.62088 453.21503
46.372574 140.15303 575.488 461.46902 294.4358
335.25153 27.828705 671.3637 390.8903 519.6835
220.51129 238.7695 229.97302 210.58838 231.8734
196.40994 237.08563 35.663574 457.76416 77.4798
256.54382 402.2121 611.9573 440.8465 -52.09662
421.86688 592.46265 237.98883 110.286194 150.65375
218.61087 240.20477 -211.85376 150.7331 310.89404
208.65228 -139.23026 -168.8883 114.3645 3.7261353
399.6282 298.28455 422.06445 526.18463 521.27563
575.85767 606.74744 394.40845 549.84033 556.4724
485.1427 512.1267 558.55896 560.4667 562.812
333.74194 ]
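Because the parameters are passed as a dict, the gradient comes back as a dict with the same keys, which is why pgrad["charge"] can be indexed above. A minimal sketch with a toy Coulomb-like energy shows this structure; toy_coulomb_energy, its unit prefactor and the toy_* values are assumptions made for illustration, not DMFF's nonbonded implementation.

```python
import jax
import jax.numpy as jnp


def toy_coulomb_energy(positions, pairs, params):
    """Sum of q_i * q_j / r_ij over the listed pairs (unit prefactor)."""
    q = params["charge"]
    ri = positions[pairs[:, 0]]
    rj = positions[pairs[:, 1]]
    r = jnp.linalg.norm(ri - rj, axis=1)
    return jnp.sum(q[pairs[:, 0]] * q[pairs[:, 1]] / r)


toy_positions = jnp.array([
    [0.0, 0.00, 0.0],
    [0.5, 0.00, 0.0],
    [0.0, 0.25, 0.0],
])
toy_pairs = jnp.array([[0, 1], [0, 2]])
toy_params = {"charge": jnp.array([0.5, -1.0, 1.0])}

# Differentiate w.r.t. the parameter dict (argument index 2)
toy_pgrad = jax.grad(toy_coulomb_energy, argnums=2)(
    toy_positions, toy_pairs, toy_params
)

# E = 2*q0*q1 + 4*q0*q2, so dE/dq = [2*q1 + 4*q2, 2*q0, 4*q0] = [2, 1, 2]
print(toy_pgrad["charge"])
```

Each entry of the gradient has the same shape as the parameter it belongs to, so the result can be fed directly into an optimizer that updates the parameter dict entry by entry.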