NeighborList

Source code in dmff/common/nblist.py
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
class NeighborList:
    """Stateful wrapper around jax_md's periodic space and neighbor list.

    Holds the box and cutoff, the displacement/shift functions produced by
    ``jax_md.space.periodic_general``, and the neighbor-list functions produced by
    ``jax_md.partition.neighbor_list`` (``OrderedSparse`` format). The most recent
    ``jax_md.partition.NeighborList`` is kept in ``self.nblist``, which is ``None``
    until :meth:`allocate` (or :meth:`update`) is first called.
    """

    def __init__(self, box, rc) -> None:
        """Wrapper of ``jax_md.space.periodic_general`` and ``jax_md.partition.neighbor_list``.

        Args:
            box (jnp.ndarray): A (spatial_dim, spatial_dim) affine transformation or [lx, ly, lz] vector
            rc (float): cutoff radius
        """
        self.box = box
        self.rc = rc
        self.displacement_fn, self.shift_fn = space.periodic_general(box, fractional_coordinates=False)
        # The positional ``0`` is passed as jax_md's ``dr_threshold`` (no skin buffer).
        self.neighborlist_fn = partition.neighbor_list(self.displacement_fn, box, rc, 0, format=partition.OrderedSparse)
        self.nblist = None

    def allocate(self, positions: jnp.ndarray):
        """Allocate a new neighbor list, or refresh the one already allocated.

        This function cannot be compiled, since it uses the values of positions to
        infer the shapes. If a neighbor list already exists it is not re-allocated;
        it is refreshed through :meth:`update` instead.

        Args:
            positions (jnp.ndarray): particle positions

        Returns:
            jax_md.partition.NeighborList: the stored list (``self.nblist``).
        """
        if self.nblist is not None:
            self.update(positions)
        else:
            self.nblist = self.neighborlist_fn.allocate(positions)
        return self.nblist

    def update(self, positions: jnp.ndarray):
        """Update the neighbor list for a new set of positions and store the result.

        ``self.nblist.update`` returns a new ``jax_md.partition.NeighborList``
        rather than modifying the existing one. The returned list is therefore
        written back to ``self.nblist``. If no list has been allocated yet, one is
        allocated from ``positions`` instead.

        Args:
            positions (jnp.ndarray): particle positions

        Returns:
            jax_md.partition.NeighborList: the refreshed list (also stored in ``self.nblist``).
        """
        if self.nblist is None:
            # No previous list to update: build one rather than failing on ``None.update``.
            self.nblist = self.neighborlist_fn.allocate(positions)
            return self.nblist
        jit_deco = jit_condition()
        # Reassign: discarding the return value would leave ``self.nblist`` stale.
        self.nblist = jit_deco(self.nblist.update)(positions)
        return self.nblist

    @property
    def pairs(self):
        """Raw pair index table of the current neighbor list.

        The stored ``idx`` array is transposed so that each row is one pair.

        Returns:
            jnp.ndarray: (nPairs, 2)
        """
        index_table = self.nblist.idx
        return index_table.T

    @property
    def pair_mask(self):
        """Return the regularized pair indices together with a validity mask.

        Entries of the raw pair table equal to the particle count
        (``len(self.positions)``) are treated as padding. A pair is valid (mask
        True) only if neither of its indices is padding.

        Returns:
            (jnp.ndarray, jnp.ndarray): ((nPairs, 2), (nPairs, ))
        """
        raw_pairs = self.pairs
        padding_index = len(self.positions)
        has_padding = jnp.any(raw_pairs == padding_index, axis=1)
        valid = jnp.logical_not(has_padding)
        return regularize_pairs(raw_pairs), valid

    @property
    def positions(self):
        """Positions stored as the reference of the current neighbor list.

        This is ``reference_position`` of the underlying ``jax_md`` neighbor list.

        Returns:
            jnp.ndarray: (n, 3)
        """
        current = self.nblist
        return current.reference_position

    @property
    def dr(self):
        """Return the displacement vector ``positions[i] - positions[j]`` for every pair.

        The vectors are a plain coordinate difference over the regularized pair
        indices from ``pair_mask``. No periodic (minimum-image) wrapping via
        ``self.displacement_fn`` is applied. Padded pairs are not filtered out;
        use the mask from ``pair_mask`` to discard them.

        Returns:
            jnp.ndarray: (nPairs, 3)
        """
        regularized, _unused_mask = self.pair_mask
        coords = self.positions
        first, second = regularized[:, 0], regularized[:, 1]
        return coords[first] - coords[second]

    @property
    def distance(self):
        """Euclidean length of every pair displacement vector in ``dr``.

        Returns:
            jnp.ndarray: (nPairs, )
        """
        vectors = self.dr
        return jnp.linalg.norm(vectors, axis=1)

__init__(box, rc)

Wrapper of jax_md.space.periodic_general and jax_md.partition.neighbor_list.

Parameters:

Name Type Description Default
box jnp.ndarray

A (spatial_dim, spatial_dim) affine transformation or [lx, ly, lz] vector

required
rc float

cutoff radius

required
Source code in dmff/common/nblist.py
 9
10
11
12
13
14
15
16
17
18
19
20
def __init__(self, box, rc) -> None:
    """Wrapper of ``jax_md.space.periodic_general`` and ``jax_md.partition.neighbor_list``.

    Args:
        box (jnp.ndarray): A (spatial_dim, spatial_dim) affine transformation or [lx, ly, lz] vector
        rc (float): cutoff radius
    """
    self.box = box
    self.rc = rc
    self.displacement_fn, self.shift_fn = space.periodic_general(box, fractional_coordinates=False)
    # The positional ``0`` is passed as jax_md's ``dr_threshold`` (no skin buffer).
    self.neighborlist_fn = partition.neighbor_list(self.displacement_fn, box, rc, 0, format=partition.OrderedSparse)
    self.nblist = None

allocate(positions)

A function to allocate a new neighbor list. This function cannot be compiled, since it uses the values of positions to infer the shapes.

Parameters:

Name Type Description Default
positions jnp.ndarray

particle positions

required

Returns:

Type Description

jax_md.partition.NeighborList

Source code in dmff/common/nblist.py
22
23
24
25
26
27
28
29
30
31
32
33
34
35
def allocate(self, positions: jnp.ndarray):
    """Allocate a new neighbor list, or refresh the one already allocated.

    This function cannot be compiled, since it uses the values of positions to
    infer the shapes. If a neighbor list already exists it is not re-allocated;
    it is refreshed through ``update`` instead.

    Args:
        positions (jnp.ndarray): particle positions

    Returns:
        jax_md.partition.NeighborList: the stored list (``self.nblist``).
    """
    if self.nblist is not None:
        self.update(positions)
    else:
        self.nblist = self.neighborlist_fn.allocate(positions)
    return self.nblist

distance() property

get pair distance in current neighborlist

Returns:

Type Description

jnp.ndarray: (nPairs, )

Source code in dmff/common/nblist.py
 93
 94
 95
 96
 97
 98
 99
100
101
@property
def distance(self):
    """Euclidean length of every pair displacement vector in ``dr``.

    Returns:
        jnp.ndarray: (nPairs, )
    """
    vectors = self.dr
    return jnp.linalg.norm(vectors, axis=1)

dr() property

get pair distance vector in current neighborlist

Returns:

Type Description

jnp.ndarray: (nPairs, 3)

Source code in dmff/common/nblist.py
83
84
85
86
87
88
89
90
91
@property
def dr(self):
    """Return the displacement vector ``positions[i] - positions[j]`` for every pair.

    The vectors are a plain coordinate difference over the regularized pair
    indices from ``pair_mask``. No periodic (minimum-image) wrapping is applied.
    Padded pairs are not filtered out; use the mask from ``pair_mask`` to
    discard them.

    Returns:
        jnp.ndarray: (nPairs, 3)
    """
    regularized, _unused_mask = self.pair_mask
    coords = self.positions
    first, second = regularized[:, 0], regularized[:, 1]
    return coords[first] - coords[second]

pair_mask() property

get regularized pair index and mask

Returns:

Type Description
jnp.ndarray, jnp.ndarray

((nPairs, 2), (nPairs, ))

Source code in dmff/common/nblist.py
60
61
62
63
64
65
66
67
68
69
70
71
72
@property
def pair_mask(self):
    """Return the regularized pair indices together with a validity mask.

    Entries of the raw pair table equal to the particle count
    (``len(self.positions)``) are treated as padding. A pair is valid (mask
    True) only if neither of its indices is padding.

    Returns:
        (jnp.ndarray, jnp.ndarray): ((nPairs, 2), (nPairs, ))
    """
    raw_pairs = self.pairs
    padding_index = len(self.positions)
    has_padding = jnp.any(raw_pairs == padding_index, axis=1)
    valid = jnp.logical_not(has_padding)
    return regularize_pairs(raw_pairs), valid

pairs() property

get raw pair index

Returns:

Type Description

jnp.ndarray: (nPairs, 2)

Source code in dmff/common/nblist.py
51
52
53
54
55
56
57
58
@property
def pairs(self):
    """Raw pair index table of the current neighbor list.

    The stored ``idx`` array is transposed so that each row is one pair.

    Returns:
        jnp.ndarray: (nPairs, 2)
    """
    index_table = self.nblist.idx
    return index_table.T

positions() property

get current positions in current neighborlist

Returns:

Type Description

jnp.ndarray: (n, 3)

Source code in dmff/common/nblist.py
74
75
76
77
78
79
80
81
@property
def positions(self):
    """Positions stored as the reference of the current neighbor list.

    This is ``reference_position`` of the underlying ``jax_md`` neighbor list.

    Returns:
        jnp.ndarray: (n, 3)
    """
    current = self.nblist
    return current.reference_position

update(positions)

A function to update a neighbor list given a new set of positions and a previously allocated neighbor list.

Parameters:

Name Type Description Default
positions jnp.ndarray

particle positions

required

Returns:

Type Description

jax_md.partition.NeighborList

Source code in dmff/common/nblist.py
37
38
39
40
41
42
43
44
45
46
47
48
49
def update(self, positions: jnp.ndarray):
    """Update the neighbor list for a new set of positions and store the result.

    ``self.nblist.update`` returns a new ``jax_md.partition.NeighborList``
    rather than modifying the existing one. The returned list is therefore
    written back to ``self.nblist``. If no list has been allocated yet, one is
    allocated from ``positions`` instead.

    Args:
        positions (jnp.ndarray): particle positions

    Returns:
        jax_md.partition.NeighborList: the refreshed list (also stored in ``self.nblist``).
    """
    if self.nblist is None:
        # No previous list to update: build one rather than failing on ``None.update``.
        self.nblist = self.neighborlist_fn.allocate(positions)
        return self.nblist
    jit_deco = jit_condition()
    # Reassign: discarding the return value would leave ``self.nblist`` stale.
    self.nblist = jit_deco(self.nblist.update)(positions)
    return self.nblist