11.4. Run MD with JAX-MD#

Note

See Environment variables for the runtime environment variables.

DeePMD-kit provides a JAX-MD compatible interface for DeePMD models trained with the JAX backend. The interface adapts a DeePMD model to the usual JAX-MD style, where a neighbor-list factory and an energy function are passed to JAX-MD simulation routines.

The interface is available from deepmd.jax.jax_md.

11.4.1. Requirements#

Install DeePMD-kit with the JAX backend and install JAX-MD. The JAX-MD package is an optional runtime dependency and is not required for other DeePMD-kit interfaces.

11.4.2. Basic usage#

The most common entry point is as_jax_md, which returns a JAX-MD neighbor-list function and a potential energy function:

import jax
import jax.numpy as jnp
from jax_md import space

from deepmd.jax.jax_md import as_jax_md

box = jnp.asarray([12.4447, 12.4447, 12.4447])
coord = jnp.asarray(...)  # shape: (natoms, 3)
atom_types = jnp.asarray(...)  # shape: (natoms,), DeePMD type indexes

displacement_fn, shift_fn = space.periodic(box)
neighbor_fn, potential_fn = as_jax_md(
    "model.ckpt.jax",
    displacement_fn,
    box,
    atom_types,
    dr_threshold=0.2,
    capacity_multiplier=1.5,
)

neighbor = neighbor_fn.allocate(coord)
energy = potential_fn(coord, neighbor=neighbor)
force = -jax.grad(lambda x: potential_fn(x, neighbor=neighbor))(coord)

The returned potential_fn accepts a single-frame coordinate array with shape (natoms, 3) and returns the scalar total energy. The optional neighbor argument should be a dense JAX-MD neighbor list allocated by the returned neighbor_fn.

11.4.3. Running dynamics#

The potential can be used with JAX-MD simulation routines. A minimal NVE loop looks like:

from jax_md import simulate

K_B_EV_PER_K = 8.617333262145e-5
kT = K_B_EV_PER_K * 330.0
mass = jnp.ones((coord.shape[0], 1))

init_fn, step_fn = simulate.nve(potential_fn, shift_fn, dt=0.0005)
state = init_fn(jax.random.key(0), coord, kT=kT, mass=mass, neighbor=neighbor)

for _ in range(10):
    neighbor = neighbor_fn.update(state.position, neighbor)
    state = step_fn(state, neighbor=neighbor)

For a complete water example using the same 192-atom configuration as the LAMMPS example, see examples/water/jax_md.

11.4.4. Model files#

deepmd.jax.jax_md.load_model accepts:

  • a DeePMD JAX checkpoint path ending in .jax,

  • an already constructed JAX DeePMD model object.

The atom_types argument may be an integer array of DeePMD type indexes. It may also be a sequence of type names if the model has a type_map.

11.4.5. Neighbor lists#

The helper neighbor_list creates a dense JAX-MD neighbor-list function using the model cutoff:

from deepmd.jax.jax_md import energy_fn, neighbor_list

neighbor_fn = neighbor_list("model.ckpt.jax", displacement_fn, box)
potential_fn = energy_fn(
    "model.ckpt.jax",
    atom_types,
    box=box,
    displacement_fn=displacement_fn,
)

Only dense JAX-MD neighbor lists are currently supported. If the neighbor-list buffer overflows during a simulation, increase capacity_multiplier or rebuild the neighbor list with a larger capacity.

11.4.6. Units#

The JAX-MD interface does not perform unit conversion. Coordinates, box vectors, energies, forces, masses, and timesteps should be provided in units consistent with the DeePMD model and the chosen JAX-MD simulation setup.