deepmd.dpmodel.utils.batch

Normalize raw batches from DeepmdDataSystem into a canonical format.

Attributes

_DROP_KEYS

_INPUT_KEYS

Functions

normalize_batch(→ dict[str, Any])

Normalize a raw batch from DeepmdDataSystem to a canonical format.

split_batch(→ tuple[dict[str, Any], dict[str, Any]])

Split a normalized batch into input and label dicts.

Module Contents

deepmd.dpmodel.utils.batch._DROP_KEYS
deepmd.dpmodel.utils.batch._INPUT_KEYS
deepmd.dpmodel.utils.batch.normalize_batch(batch: dict[str, Any]) → dict[str, Any]

Normalize a raw batch from DeepmdDataSystem to a canonical format.

The following conversions are applied:

  • "type" is renamed to "atype" (int64).

  • "natoms_vec" (1-D) is tiled to 2-D [nframes, 2+ntypes] and stored as "natoms".

  • find_* flags are converted to np.bool_.

  • Metadata keys (default_mesh, sid, fid) are dropped.

Parameters:
batch : dict[str, Any]

Raw batch dict returned by DeepmdDataSystem.get_batch().

Returns:
dict[str, Any]

Normalized batch dict (new dict; the input is not mutated).
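The four conversions listed above can be illustrated with a minimal, self-contained sketch. It is not the library's implementation: the name normalize_batch_sketch and the constant DROP_KEYS_SKETCH are hypothetical, the values are assumed to be NumPy arrays or scalars, and nframes is assumed to be the leading dimension of the "type" array.

```python
from typing import Any

import numpy as np

# Hypothetical stand-in for the metadata keys named in the docstring.
DROP_KEYS_SKETCH = {"default_mesh", "sid", "fid"}


def normalize_batch_sketch(batch: dict[str, Any]) -> dict[str, Any]:
    """Hypothetical re-implementation of the documented conversions."""
    out: dict[str, Any] = {}  # build a new dict so the input is never mutated
    for key, value in batch.items():
        if key in DROP_KEYS_SKETCH:
            continue  # drop metadata keys
        if key == "type":
            out["atype"] = np.asarray(value, dtype=np.int64)  # rename and cast
        elif key == "natoms_vec":
            continue  # handled after the loop, once nframes is known
        elif key.startswith("find_"):
            out[key] = np.bool_(value)  # scalar flag -> NumPy bool scalar
        else:
            out[key] = value  # every other key passes through unchanged
    if "natoms_vec" in batch:
        natoms_vec = np.asarray(batch["natoms_vec"])
        # Assumption: nframes is the leading dimension of the atom-type array.
        nframes = np.asarray(batch["type"]).shape[0]
        # Repeat the 1-D vector once per frame: shape [nframes, 2 + ntypes].
        out["natoms"] = np.tile(natoms_vec, (nframes, 1))
    return out
```

Building the result in a fresh dict is one straightforward way to honor the documented guarantee that the input batch is not mutated.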

deepmd.dpmodel.utils.batch.split_batch(batch: dict[str, Any]) → tuple[dict[str, Any], dict[str, Any]]

Split a normalized batch into input and label dicts.

Parameters:
batch : dict[str, Any]

Normalized batch (output of normalize_batch()).

Returns:
input_dict : dict[str, Any]

Model inputs (coord, atype, box, fparam, aparam, spin).

label_dict : dict[str, Any]

Labels and find flags (energy, force, virial, find_*, natoms, …).
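The split can be pictured as a partition of the normalized batch by key. The sketch below is illustrative only: split_batch_sketch and INPUT_KEYS_SKETCH are hypothetical names, the input key set is copied from the Returns description above, and it assumes that every key not in that set is treated as a label.

```python
from typing import Any

# Hypothetical input key set, copied from the documented model inputs.
INPUT_KEYS_SKETCH = frozenset({"coord", "atype", "box", "fparam", "aparam", "spin"})


def split_batch_sketch(
    batch: dict[str, Any],
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Hypothetical partition of a normalized batch into (inputs, labels)."""
    # Keys the model consumes as inputs.
    input_dict = {k: v for k, v in batch.items() if k in INPUT_KEYS_SKETCH}
    # Assumption: all remaining keys (labels, find_* flags, natoms, ...) are labels.
    label_dict = {k: v for k, v in batch.items() if k not in INPUT_KEYS_SKETCH}
    return input_dict, label_dict
```

Because the two comprehensions use complementary conditions, each key lands in exactly one of the two dicts and nothing is silently lost.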