NeighborList

Source code in dmff/common/nblist.py
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
class NeighborList:
    """Wrapper around a jax_md periodic space and an OrderedSparse jax_md neighbor list."""

    def __init__(self, box, r_cutoff, covalent_map, dr_threshold=0, capacity_multiplier=1.25, format: Literal['dense', 'sparse'] = 'sparse') -> None:
        """Wrapper of jax_md.space.periodic_general and jax_md.partition.neighbor_list.

        Args:
            box (jnp.ndarray): A (spatial_dim, spatial_dim) affine transformation or [lx, ly, lz] vector
            r_cutoff (float): cutoff radius (also exposed as ``self.rc``)
            covalent_map (jnp.ndarray): matrix indexed as ``covalent_map[i, j]``; the entry for
                each pair is appended as a third column by ``pairs``
            dr_threshold (float): forwarded to ``partition.neighbor_list``
            capacity_multiplier (float): forwarded to ``partition.neighbor_list``
            format: accepted for backward compatibility but currently ignored; the neighbor
                list is always built in ``partition.OrderedSparse`` format, whose ``idx``
                layout ``pairs`` relies on.
        """
        self.box = box
        self.rc = self.r_cutoff = r_cutoff

        self.dr_threshold = dr_threshold
        self.capacity_multiplier = capacity_multiplier

        self.covalent_map = covalent_map
        self.displacement_fn, self.shift_fn = space.periodic_general(
            box, fractional_coordinates=False
        )
        # Previously capacity_multiplier was stored but never passed on, so a
        # user-supplied value had no effect.
        self.neighborlist_fn = partition.neighbor_list(
            self.displacement_fn,
            box,
            r_cutoff,
            dr_threshold,
            capacity_multiplier=capacity_multiplier,
            format=partition.OrderedSparse,
        )
        self.nblist = None

    def allocate(self, positions: jnp.ndarray, box: Optional[jnp.ndarray] = None):
        """Allocate the neighbor list, or update it if one already exists.

        Allocation cannot be compiled, since it uses the values of positions to infer the shapes.

        Args:
            positions (jnp.ndarray): particle positions
            box (jnp.ndarray, optional): box to use instead of the construction-time box

        Returns:
            jnp.ndarray: (nPairs, 3) pair indices plus covalent-map column (see ``pairs``)
        """
        if self.nblist is None:
            if box is None:
                self.nblist = self.neighborlist_fn.allocate(positions)
            else:
                # NOTE(review): relies on jax_md accepting a ``box=`` keyword on
                # allocation, as it does for ``update``; confirm for the pinned version.
                self.nblist = self.neighborlist_fn.allocate(positions, box=box)
        else:
            self.update(positions, box)
        return self.pairs

    def update(self, positions: jnp.ndarray, box: Optional[jnp.ndarray] = None):
        """Update a previously allocated neighbor list with new positions.

        Args:
            positions (jnp.ndarray): particle positions
            box (jnp.ndarray, optional): new box; forwarded to jax_md as ``box=``

        Returns:
            jnp.ndarray: (nPairs, 3) pair indices plus covalent-map column (see ``pairs``)

        Raises:
            RuntimeError: if ``allocate`` has not been called yet.
        """
        if self.nblist is None:
            raise RuntimeError("run nblist.allocate(positions) first")
        if box is None:
            self.nblist = self.nblist.update(positions)
        else:
            # jax_md's NeighborList.update(position, **kwargs) accepts only the
            # position positionally, so the box must be passed by keyword.
            self.nblist = self.nblist.update(positions, box=box)
        return self.pairs

    @property
    def pairs(self):
        """Get raw pair indices with their covalent-map entry.

        Returns:
            jnp.ndarray: (nPairs, 3); columns are i, j and ``covalent_map[i, j]``

        Raises:
            RuntimeError: if ``allocate`` has not been called yet.
        """
        if self.nblist is None:
            raise RuntimeError("run nblist.allocate(positions) first")
        pairs = self.nblist.idx.T
        nbond = self.covalent_map[pairs[:, 0], pairs[:, 1]]
        return jnp.concatenate([pairs, nbond[:, None]], axis=1)

    @property
    def scaled_pairs(self):
        """Get regularized pair indices with empty slots removed.

        A pair is dropped when either index equals ``len(self.positions)`` (the
        number of particles). The remaining pairs are passed through
        ``regularize_pairs`` and their covalent-map entry is looked up again.

        NOTE: boolean-mask indexing gives a data-dependent shape, so this
        property cannot be used inside ``jax.jit``.

        Returns:
            jnp.ndarray: (nValidPairs, 3); columns are i, j and ``covalent_map[i, j]``
        """
        # Evaluate the pair array once instead of three times.
        raw = self.pairs[:, :2]
        n_particles = len(self.positions)
        keep = jnp.all(raw != n_particles, axis=1)
        pairs = regularize_pairs(raw)[keep]
        nbond = self.covalent_map[pairs[:, 0], pairs[:, 1]]
        return jnp.concatenate([pairs, nbond[:, None]], axis=1)

    @property
    def positions(self):
        """Get the reference positions stored in the current neighbor list.

        Returns:
            jnp.ndarray: (n, 3)
        """
        return self.nblist.reference_position

    @property
    def did_buffer_overflow(self) -> bool:
        """
        Return whether the neighbor list buffer overflowed.

        Returns
        -------
        bool
            The value of ``self.nblist.did_buffer_overflow`` (truthy on overflow).
        """
        return self.nblist.did_buffer_overflow

__init__(box, r_cutoff, covalent_map, dr_threshold=0, capacity_multiplier=1.25, format=Literal['dense', 'sparse'])

Wrapper of jax_md.space.periodic_general and jax_md.partition.NeighborList.

Parameters:

Name Type Description Default
box jnp.ndarray

A (spatial_dim, spatial_dim) affine transformation or [lx, ly, lz] vector

required
r_cutoff float

cutoff radius

required
Source code in dmff/common/nblist.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
def __init__(self, box, r_cutoff, covalent_map, dr_threshold=0, capacity_multiplier=1.25, format: Literal['dense', 'sparse'] = 'sparse') -> None:
    """Wrapper of jax_md.space.periodic_general and jax_md.partition.neighbor_list.

    Args:
        box (jnp.ndarray): A (spatial_dim, spatial_dim) affine transformation or [lx, ly, lz] vector
        r_cutoff (float): cutoff radius (also exposed as ``self.rc``)
        covalent_map (jnp.ndarray): matrix indexed as ``covalent_map[i, j]`` for each pair
        dr_threshold (float): forwarded to ``partition.neighbor_list``
        capacity_multiplier (float): forwarded to ``partition.neighbor_list``
        format: accepted for backward compatibility but currently ignored; the neighbor
            list is always built in ``partition.OrderedSparse`` format.
    """
    self.box = box
    self.rc = self.r_cutoff = r_cutoff

    self.dr_threshold = dr_threshold
    self.capacity_multiplier = capacity_multiplier

    self.covalent_map = covalent_map
    self.displacement_fn, self.shift_fn = space.periodic_general(
        box, fractional_coordinates=False
    )
    # Forward capacity_multiplier; previously it was stored but had no effect.
    self.neighborlist_fn = partition.neighbor_list(
        self.displacement_fn,
        box,
        r_cutoff,
        dr_threshold,
        capacity_multiplier=capacity_multiplier,
        format=partition.OrderedSparse,
    )
    self.nblist = None

allocate(positions, box=None)

A function to allocate a new neighbor list. This function cannot be compiled, since it uses the values of positions to infer the shapes.

Parameters:

Name Type Description Default
positions jnp.ndarray

particle positions

required

Returns:

Type Description

jnp.ndarray: (nPairs, 3) pair indices; the third column is covalent_map[i, j] for each pair.

Source code in dmff/common/nblist.py
34
35
36
37
38
39
40
41
42
43
44
45
46
47
def allocate(self, positions: jnp.ndarray, box: Optional[jnp.ndarray] = None):
    """Allocate the neighbor list, or update it if one already exists.

    Allocation cannot be compiled, since it uses the values of positions to infer the shapes.

    Args:
        positions (jnp.ndarray): particle positions
        box (jnp.ndarray, optional): box to use instead of the construction-time box

    Returns:
        jnp.ndarray: (nPairs, 3) pair indices plus covalent-map column (see ``pairs``)
    """
    if self.nblist is None:
        if box is None:
            self.nblist = self.neighborlist_fn.allocate(positions)
        else:
            # Previously ``box`` was silently ignored on first allocation.
            # NOTE(review): relies on jax_md accepting ``box=`` here; confirm for the pinned version.
            self.nblist = self.neighborlist_fn.allocate(positions, box=box)
    else:
        self.update(positions, box)
    return self.pairs

did_buffer_overflow() property

if the neighborlist buffer overflowed, return True

Returns

bool

Source code in dmff/common/nblist.py
101
102
103
104
105
106
107
108
109
110
@property
def did_buffer_overflow(self) -> bool:
    """
    Report whether the wrapped neighbor list overflowed its buffer.

    Returns
    -------
    bool
        The value of ``self.nblist.did_buffer_overflow`` (truthy on overflow).
    """
    current = self.nblist
    return current.did_buffer_overflow

pairs() property

get raw pair index

Returns:

Type Description

jnp.ndarray: (nPairs, 3) — columns are i, j and covalent_map[i, j]

Source code in dmff/common/nblist.py
64
65
66
67
68
69
70
71
72
73
74
75
@property
def pairs(self):
    """Get raw pair indices with their covalent-map entry.

    Returns:
        jnp.ndarray: (nPairs, 3); columns are i, j and ``covalent_map[i, j]``

    Raises:
        RuntimeError: if ``allocate`` has not been called yet.
    """
    if self.nblist is None:
        raise RuntimeError("run nblist.allocate(positions) first")
    # idx is transposed so that each row holds one (i, j) pair.
    pairs = self.nblist.idx.T
    nbond = self.covalent_map[pairs[:, 0], pairs[:, 1]]
    return jnp.concatenate([pairs, nbond[:, None]], axis=1)

positions() property

get current positions in current neighborlist

Returns:

Type Description

jnp.ndarray: (n, 3)

Source code in dmff/common/nblist.py
92
93
94
95
96
97
98
99
@property
def positions(self):
    """Return the reference positions held by the current neighbor list.

    Returns:
        jnp.ndarray: (n, 3)
    """
    current = self.nblist
    return current.reference_position

scaled_pairs() property

get regularized pair indices with empty slots removed

Returns:

Type Description
jnp.ndarray

(nPairs, 3): regularized pairs with empty slots removed; the third column is covalent_map[i, j]

Source code in dmff/common/nblist.py
77
78
79
80
81
82
83
84
85
86
87
88
89
90
@property
def scaled_pairs(self):
    """Get regularized pair indices with empty slots removed.

    A pair is dropped when either index equals ``len(self.positions)`` (the
    number of particles). The remaining pairs are passed through
    ``regularize_pairs`` and their covalent-map entry is looked up again.

    NOTE: boolean-mask indexing gives a data-dependent shape, so this
    property cannot be used inside ``jax.jit``.

    Returns:
        jnp.ndarray: (nValidPairs, 3); columns are i, j and ``covalent_map[i, j]``
    """
    # Evaluate the pair array once; the original recomputed self.pairs three times.
    raw = self.pairs[:, :2]
    n_particles = len(self.positions)
    keep = jnp.all(raw != n_particles, axis=1)
    pairs = regularize_pairs(raw)[keep]
    nbond = self.covalent_map[pairs[:, 0], pairs[:, 1]]
    return jnp.concatenate([pairs, nbond[:, None]], axis=1)

update(positions, box=None)

A function to update a neighbor list given a new set of positions and a previously allocated neighbor list.

Parameters:

Name Type Description Default
positions jnp.ndarray

particle positions

required

Returns:

Type Description

jnp.ndarray: (nPairs, 3) pair indices; the third column is covalent_map[i, j] for each pair.

Source code in dmff/common/nblist.py
49
50
51
52
53
54
55
56
57
58
59
60
61
62
def update(self, positions: jnp.ndarray, box: Optional[jnp.ndarray] = None):
    """Update a previously allocated neighbor list with new positions.

    Args:
        positions (jnp.ndarray): particle positions
        box (jnp.ndarray, optional): new box; forwarded to jax_md as ``box=``

    Returns:
        jnp.ndarray: (nPairs, 3) pair indices plus covalent-map column (see ``pairs``)

    Raises:
        RuntimeError: if ``allocate`` has not been called yet.
    """
    if self.nblist is None:
        raise RuntimeError("run nblist.allocate(positions) first")
    if box is None:
        self.nblist = self.nblist.update(positions)
    else:
        # jax_md's NeighborList.update(position, **kwargs) takes only the
        # position positionally, so the box must be passed by keyword.
        self.nblist = self.nblist.update(positions, box=box)
    return self.pairs