ADMPPmeForce

This is a convenient wrapper for multipolar PME calculations. It wraps all the environment parameters of a multipolar PME calculation. The so-called "environment parameters" are parameters that do not need to be differentiable.

Source code in dmff/admp/pme.py
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
class ADMPPmeForce:
    '''
    This is a convenient wrapper for multipolar PME calculations.
    It wraps all the environment parameters of a multipolar PME calculation.
    The so-called "environment parameters" are parameters that do not need to be differentiable.
    '''

    def __init__(self, box, axis_type, axis_indices, rc, ethresh, lmax, lpol=False, lpme=True, steps_pol=None):
        '''
        Initialize the ADMPPmeForce calculator.

        Input:
            box:
                (3, 3) float, box size in row
            axis_type:
                (na,) int, types of local axis (bisector, z-then-x etc.)
            axis_indices:
                indices used, together with axis_type, to construct the local frames
                (passed to generate_construct_local_frames)
            rc:
                float: cutoff distance
            ethresh:
                float: pme energy threshold
            lmax:
                int: max L for multipoles
            lpol:
                bool: polarize or not?
            lpme:
                bool: do pme or simple cutoff?
                if falsy, kappa and K1/K2/K3 are set to zero and the reciprocal part will not be computed
            steps_pol:
                None or int: Whether do fixed number of dipole iteration steps?
                if None: converge dipoles until convergence threshold is met
                if int: optimize for this many steps and stop, this is useful if you want to jit the entire function

        Output:
            None. The energy/force calculators are attached to self by refresh_calculators.
        '''
        self.axis_type = axis_type
        self.axis_indices = axis_indices
        self.rc = rc
        self.ethresh = ethresh
        self.lmax = int(lmax)  # jichen: type checking
        # turn off pme if lpme is falsy, this is useful when doing cluster calculations.
        # Truthiness (rather than `is False`) matches energy_pme, which skips the reciprocal
        # part whenever lpme is falsy; otherwise e.g. lpme=0 would keep a nonzero kappa in
        # the real-space term while the reciprocal term is skipped.
        self.lpme = lpme
        if not self.lpme:
            self.kappa = 0
            self.K1 = 0
            self.K2 = 0
            self.K3 = 0
        else:
            self.kappa, self.K1, self.K2, self.K3 = setup_ewald_parameters(rc, ethresh, box)
        self.pme_order = 6
        self.lpol = lpol
        self.steps_pol = steps_pol
        self.n_atoms = len(axis_type)

        # setup calculators
        self.refresh_calculators()

    def generate_get_energy(self):
        '''
        Build the energy function for the current environment.

        Output:
            get_energy:
                non-polarizable: get_energy(positions, box, pairs, Q_local, mScales)
                polarizable: get_energy(positions, box, pairs, Q_local, pol, tholes,
                mScales, pScales, dScales, U_init=...), which first converges the induced
                dipoles with optimize_Uind and then evaluates the energy

        Side effects (polarizable case only):
            sets self.energy_fn, self.grad_U_fn, self.grad_pos_fn and resets self.U_ind to zeros
        '''
        # if the force field is not polarizable
        if not self.lpol:
            def get_energy(positions, box, pairs, Q_local, mScales):
                return energy_pme(positions, box, pairs,
                                 Q_local, None, None, None,
                                 mScales, None, None,
                                 self.construct_local_frames, self.pme_recip,
                                 self.kappa, self.K1, self.K2, self.K3, self.lmax, False, lpme=self.lpme)
            return get_energy

        # this is the bare energy calculator, with Uind as explicit input
        def energy_fn(positions, box, pairs, Q_local, Uind_global, pol, tholes, mScales, pScales, dScales):
            return energy_pme(positions, box, pairs,
                             Q_local, Uind_global, pol, tholes,
                             mScales, pScales, dScales,
                             self.construct_local_frames, self.pme_recip,
                             self.kappa, self.K1, self.K2, self.K3, self.lmax, True, lpme=self.lpme)
        self.energy_fn = energy_fn
        # derivative w.r.t. Uind_global (argument 4): used as the "field" in optimize_Uind
        self.grad_U_fn = grad(self.energy_fn, argnums=(4))
        self.grad_pos_fn = grad(self.energy_fn, argnums=(0))
        self.U_ind = jnp.zeros((self.n_atoms, 3))

        # this is the wrapper that includes a Uind optimizer.
        # NOTE: the default of U_init is evaluated once, when get_energy is defined, so it is
        # the zero array created just above, not whatever self.U_ind holds at call time.
        def get_energy(
                positions, box, pairs,
                Q_local, pol, tholes, mScales, pScales, dScales,
                U_init=self.U_ind):
            self.U_ind, self.lconverg, self.n_cycle = self.optimize_Uind(
                    positions, box, pairs, Q_local, pol, tholes,
                    mScales, pScales, dScales,
                    U_init=U_init, steps_pol=self.steps_pol)
            # here we rely on Hellmann-Feynman theorem, drop the term dV/dU*dU/dr !
            return self.energy_fn(positions, box, pairs, Q_local, self.U_ind, pol, tholes, mScales, pScales, dScales)
        return get_energy

    def update_env(self, attr, val):
        '''
        Update one environment attribute of the calculator and regenerate the calculators.

        Input:
            attr: str, name of the attribute to overwrite
            val: the new value
        '''
        setattr(self, attr, val)
        self.refresh_calculators()

    def refresh_calculators(self):
        '''
        Refresh the energy and force calculators according to the current environment.

        Sets self.construct_local_frames (None when lmax == 0), self.pme_recip,
        self.get_energy and self.get_forces (value_and_grad of get_energy).
        '''
        if self.lmax > 0:
            self.construct_local_frames = generate_construct_local_frames(self.axis_type, self.axis_indices)
        else:
            self.construct_local_frames = None
        lmax = self.lmax
        # for polarizable monopole force field, need to increase lmax to 1, accommodating induced dipoles.
        # Truthiness matches generate_get_energy (`if not self.lpol`), so e.g. lpol=1 also gets a
        # reciprocal calculator able to handle the induced dipoles.
        if self.lmax == 0 and self.lpol:
            lmax = 1
        self.pme_recip = generate_pme_recip(Ck_1, self.kappa, False, self.pme_order, self.K1, self.K2, self.K3, lmax)
        # generate the force calculator
        self.get_energy = self.generate_get_energy()
        self.get_forces = value_and_grad(self.get_energy)

    def optimize_Uind(self,
            positions, box, pairs,
            Q_local, pol, tholes, mScales, pScales, dScales,
            U_init=None, steps_pol=None,
            maxiter=MAX_N_POL, thresh=POL_CONV):
        '''
        Converge the induced dipoles.

        Note that we cut all the gradient chains passing through this function, as we assume
        the Hellmann-Feynman theorem: gradients related to Uind should be dropped.

        Input:
            positions, box, pairs, Q_local, pol, tholes, mScales, pScales, dScales:
                same as the corresponding arguments of self.energy_fn
            U_init:
                (Na, 3) initial induced dipoles; zeros if None
            steps_pol:
                None: iterate until the largest |field| component on the polarizable sites
                (pol > 0.001) drops below thresh, evaluating the field at most maxiter times
                int: run exactly this many update steps with jax.lax.fori_loop (jittable)
            maxiter:
                int: max number of iterations when steps_pol is None
            thresh:
                float: convergence threshold on the field when steps_pol is None

        Output:
            U:
                (Na, 3) induced dipoles
            flag:
                bool: True if converged (always True in the fixed-step mode)
            n_cycle:
                number of dipole updates applied (equal to steps_pol in the fixed-step mode)
        '''
        # Do not track gradient in Uind optimization
        positions = jax.lax.stop_gradient(positions)
        box = jax.lax.stop_gradient(box)
        Q_local = jax.lax.stop_gradient(Q_local)
        pol = jax.lax.stop_gradient(pol)
        tholes = jax.lax.stop_gradient(tholes)
        mScales = jax.lax.stop_gradient(mScales)
        pScales = jax.lax.stop_gradient(pScales)
        dScales = jax.lax.stop_gradient(dScales)
        U = jnp.zeros((self.n_atoms, 3)) if U_init is None else U_init

        if steps_pol is not None:
            # fixed number of steps, expressed with fori_loop so the whole thing can be jitted
            def update_U(i, U):
                field = self.grad_U_fn(positions, box, pairs, Q_local, U, pol, tholes, mScales, pScales, dScales)
                return U - field * pol[:, jnp.newaxis] / DIELECTRIC
            U = jax.lax.fori_loop(0, steps_pol, update_U, U)
            return U, True, steps_pol

        site_filter = (pol > 0.001)  # focus on the actual polarizable sites
        flag = False
        n_cycle = 0
        for n_cycle in range(maxiter):
            field = self.grad_U_fn(positions, box, pairs, Q_local, U, pol, tholes, mScales, pScales, dScales)
            # initial=0.0: with no polarizable site the max over the empty selection is 0
            # (trivially converged) instead of raising on a zero-size reduction
            if jnp.max(jnp.abs(field[site_filter]), initial=0.0) < thresh:
                flag = True
                break
            U = U - field * pol[:, jnp.newaxis] / DIELECTRIC
        else:
            # never converged (or maxiter == 0): maxiter updates were applied
            n_cycle = maxiter
        return U, flag, n_cycle

__init__(box, axis_type, axis_indices, rc, ethresh, lmax, lpol=False, lpme=True, steps_pol=None)

Initialize the ADMPPmeForce calculator.

Input

box: (3, 3) float, box size in rows; axis_type: (na,) int, types of local axis (bisector, z-then-x, etc.); axis_indices: indices used together with axis_type to construct the local frames; rc: float: cutoff distance; ethresh: float: PME energy threshold; lmax: int: max L for multipoles; lpol: bool: polarizable or not; lpme: bool: do PME or a simple cutoff? If False, kappa is set to zero and the reciprocal part is not computed; steps_pol: None or int: whether to do a fixed number of dipole iteration steps. If None, the dipoles are iterated until the convergence threshold is met; if an int, exactly this many steps are run and then the iteration stops, which is useful if you want to jit the entire function.

Output:

Source code in dmff/admp/pme.py
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
def __init__(self, box, axis_type, axis_indices, rc, ethresh, lmax, lpol=False, lpme=True, steps_pol=None):
    '''
    Initialize the ADMPPmeForce calculator.

    Input:
        box:
            (3, 3) float, box size in row
        axis_type:
            (na,) int, types of local axis (bisector, z-then-x etc.)
        axis_indices:
            indices used, together with axis_type, to construct the local frames
            (passed to generate_construct_local_frames)
        rc:
            float: cutoff distance
        ethresh:
            float: pme energy threshold
        lmax:
            int: max L for multipoles
        lpol:
            bool: polarize or not?
        lpme:
            bool: do pme or simple cutoff?
            if falsy, kappa and K1/K2/K3 are set to zero and the reciprocal part will not be computed
        steps_pol:
            None or int: Whether do fixed number of dipole iteration steps?
            if None: converge dipoles until convergence threshold is met
            if int: optimize for this many steps and stop, this is useful if you want to jit the entire function

    Output:
        None. The energy/force calculators are attached to self by refresh_calculators.
    '''
    self.axis_type = axis_type
    self.axis_indices = axis_indices
    self.rc = rc
    self.ethresh = ethresh
    self.lmax = int(lmax)  # jichen: type checking
    # turn off pme if lpme is falsy, this is useful when doing cluster calculations.
    # Truthiness (rather than `is False`) matches energy_pme, which skips the reciprocal
    # part whenever lpme is falsy; otherwise e.g. lpme=0 would keep a nonzero kappa in
    # the real-space term while the reciprocal term is skipped.
    self.lpme = lpme
    if not self.lpme:
        self.kappa = 0
        self.K1 = 0
        self.K2 = 0
        self.K3 = 0
    else:
        self.kappa, self.K1, self.K2, self.K3 = setup_ewald_parameters(rc, ethresh, box)
    self.pme_order = 6
    self.lpol = lpol
    self.steps_pol = steps_pol
    self.n_atoms = len(axis_type)

    # setup calculators
    self.refresh_calculators()

optimize_Uind(positions, box, pairs, Q_local, pol, tholes, mScales, pScales, dScales, U_init=None, steps_pol=None, maxiter=MAX_N_POL, thresh=POL_CONV)

This function converges the induced dipoles. Note that all gradient chains passing through this function are cut, because we assume the Hellmann–Feynman theorem holds: gradients related to Uind should be dropped.

Source code in dmff/admp/pme.py
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
def optimize_Uind(self,
        positions, box, pairs,
        Q_local, pol, tholes, mScales, pScales, dScales,
        U_init=None, steps_pol=None,
        maxiter=MAX_N_POL, thresh=POL_CONV):
    '''
    Converge the induced dipoles.

    Note that we cut all the gradient chains passing through this function, as we assume
    the Hellmann-Feynman theorem: gradients related to Uind should be dropped.

    Input:
        positions, box, pairs, Q_local, pol, tholes, mScales, pScales, dScales:
            same as the corresponding arguments of self.energy_fn
        U_init:
            (Na, 3) initial induced dipoles; zeros if None
        steps_pol:
            None: iterate until the largest |field| component on the polarizable sites
            (pol > 0.001) drops below thresh, evaluating the field at most maxiter times
            int: run exactly this many update steps with jax.lax.fori_loop (jittable)
        maxiter:
            int: max number of iterations when steps_pol is None
        thresh:
            float: convergence threshold on the field when steps_pol is None

    Output:
        U:
            (Na, 3) induced dipoles
        flag:
            bool: True if converged (always True in the fixed-step mode)
        n_cycle:
            number of dipole updates applied (equal to steps_pol in the fixed-step mode)
    '''
    # Do not track gradient in Uind optimization
    positions = jax.lax.stop_gradient(positions)
    box = jax.lax.stop_gradient(box)
    Q_local = jax.lax.stop_gradient(Q_local)
    pol = jax.lax.stop_gradient(pol)
    tholes = jax.lax.stop_gradient(tholes)
    mScales = jax.lax.stop_gradient(mScales)
    pScales = jax.lax.stop_gradient(pScales)
    dScales = jax.lax.stop_gradient(dScales)
    U = jnp.zeros((self.n_atoms, 3)) if U_init is None else U_init

    if steps_pol is not None:
        # fixed number of steps, expressed with fori_loop so the whole thing can be jitted
        def update_U(i, U):
            field = self.grad_U_fn(positions, box, pairs, Q_local, U, pol, tholes, mScales, pScales, dScales)
            return U - field * pol[:, jnp.newaxis] / DIELECTRIC
        U = jax.lax.fori_loop(0, steps_pol, update_U, U)
        return U, True, steps_pol

    site_filter = (pol > 0.001)  # focus on the actual polarizable sites
    flag = False
    n_cycle = 0
    for n_cycle in range(maxiter):
        field = self.grad_U_fn(positions, box, pairs, Q_local, U, pol, tholes, mScales, pScales, dScales)
        # initial=0.0: with no polarizable site the max over the empty selection is 0
        # (trivially converged) instead of raising on a zero-size reduction
        if jnp.max(jnp.abs(field[site_filter]), initial=0.0) < thresh:
            flag = True
            break
        U = U - field * pol[:, jnp.newaxis] / DIELECTRIC
    else:
        # never converged (or maxiter == 0): maxiter updates were applied
        n_cycle = maxiter
    return U, flag, n_cycle

refresh_calculators()

Refresh the energy and force calculators according to the current environment.

Source code in dmff/admp/pme.py
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
def refresh_calculators(self):
    '''
    Refresh the energy and force calculators according to the current environment.

    Sets self.construct_local_frames (None when lmax == 0), self.pme_recip,
    self.get_energy and self.get_forces (value_and_grad of get_energy).
    '''
    if self.lmax > 0:
        self.construct_local_frames = generate_construct_local_frames(self.axis_type, self.axis_indices)
    else:
        self.construct_local_frames = None
    lmax = self.lmax
    # for polarizable monopole force field, need to increase lmax to 1, accommodating induced dipoles.
    # Truthiness matches generate_get_energy (`if not self.lpol`), so e.g. lpol=1 also gets a
    # reciprocal calculator able to handle the induced dipoles.
    if self.lmax == 0 and self.lpol:
        lmax = 1
    self.pme_recip = generate_pme_recip(Ck_1, self.kappa, False, self.pme_order, self.K1, self.K2, self.K3, lmax)
    # generate the force calculator
    self.get_energy = self.generate_get_energy()
    self.get_forces = value_and_grad(self.get_energy)

update_env(attr, val)

Update one environment attribute of the calculator and regenerate its calculators.

Source code in dmff/admp/pme.py
139
140
141
142
143
144
def update_env(self, attr, val):
    '''
    Overwrite a single environment attribute and rebuild the calculators.

    Input:
        attr: str, name of the attribute to set on the calculator
        val: the new value for that attribute

    Output:
        None
    '''
    setattr(self, attr, val)
    # every generated calculator closes over the environment, so regenerate them all
    self.refresh_calculators()
    return None

calc_e_ind(dr, thole1, thole2, dmp, pscales, dscales, kappa, lmax=2)

This function calculates all the eUindCoefs at once, including the Thole damping factors for energies. eUindCoefs is essentially the interaction tensor between permanent multipole components and induced dipoles (plus the induced–induced terms). Everything should be done in the so-called quasi-internal (qi) frame.

Inputs

dr: float: distance between one pair of particles; thole1, thole2: float: Thole damping widths of the two particles (their sum is used when pscales == 0); dmp: float: damping factor between one pair of particles; pscales: float: scaling factor between permanent–induced multipole interactions, for each pair; dscales: float: scaling factor between induced–induced dipole interactions, for each pair; kappa: float: \kappa in PME, unit in A^-1; lmax: int: max L

Output

Interaction tensor components: cud, dud_m0, dud_m1, udq_m0, udq_m1, udud_m0, udud_m1

Source code in dmff/admp/pme.py
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
@jit_condition(static_argnums=(7))
def calc_e_ind(dr, thole1, thole2, dmp, pscales, dscales, kappa, lmax=2):

    r'''
    Calculate all the eUindCoefs of one pair at once, including the Thole damping
    factors for energies.
    eUindCoefs is basically the interaction tensor between permanent multipole
    components and induced dipoles (plus the induced-induced terms).
    Everything should be done in the so-called quasi-internal (qi) frame.

    Inputs:
        dr:
            float: distance between one pair of particles
        thole1, thole2:
            float: Thole damping widths of the two particles; their sum is used as the
            damping width when pscales == 0, otherwise DEFAULT_THOLE_WIDTH is used
        dmp:
            float: damping factors between one pair of particles
        pscales:
            float: scaling factor between permanent - induced multipole interactions, for each pair
        dscales:
            float: scaling factor between induced - induced dipole interactions, for each pair
        kappa:
            float: \kappa in PME, unit in A^-1
        lmax:
            int: max L (argument index 7, marked static for jit_condition)

    Output:
        cud, dud_m0, dud_m1, udq_m0, udq_m1, udud_m0, udud_m1:
            interaction tensor components; dud_* are 0.0 when lmax < 1 and
            udq_* are 0.0 when lmax < 2
    '''
    ## pscale == 0 ? thole1 + thole2 : DEFAULT_THOLE_WIDTH
    # heaviside(pscales, 0) is 1 for pscales > 0 and 0 for pscales <= 0
    w = jnp.heaviside(pscales, 0)
    a = w * DEFAULT_THOLE_WIDTH + (1-w) * (thole1+thole2)

    # presumably clamps dmp away from 0 and the ratio away from infinity
    # (trim_val_0 / trim_val_infty are defined elsewhere in the module)
    dmp = trim_val_0(dmp)
    u = trim_val_infty(dr/dmp)

    ## expau = exp(-au) for au < 50; for au >= 50 it is set to exactly 0
    au = a * u
    expau = jnp.piecewise(au, [au<50, au>=50], [lambda au: jnp.exp(-au), lambda au: jnp.array(0)])

    ## powers of au, each passed through trim_val_infty
    ## (au5 and au6 are computed but not used below)
    au2 = trim_val_infty(au*au)
    au3 = trim_val_infty(au2*au)
    au4 = trim_val_infty(au3*au)
    au5 = trim_val_infty(au4*au)
    au6 = trim_val_infty(au5*au)

    ##  Thole damping factors for energies
    thole_c   = 1.0 - expau*(1.0 + au + 0.5*au2)
    thole_d0  = 1.0 - expau*(1.0 + au + 0.5*au2 + au3/4.0)
    thole_d1  = 1.0 - expau*(1.0 + au + 0.5*au2)
    thole_q0  = 1.0 - expau*(1.0 + au + 0.5*au2 + au3/6.0 + au4/18.0)
    thole_q1  = 1.0 - expau*(1.0 + au + 0.5*au2 + au3/6.0)
    # copied from calc_e_perm
    # be aware of unit and dimension !!
    # rInvVec[i] = DIELECTRIC / dr**i ; alphaRVec[i] = (kappa*dr)**i
    rInv = 1 / dr
    rInvVec = jnp.array([DIELECTRIC*(rInv**i) for i in range(0, 9)])
    alphaRVec = jnp.array([(kappa*dr)**i for i in range(0, 10)])
    # X = 2 exp(-(kappa*dr)^2) / sqrt(pi)
    X = 2 * jnp.exp(-alphaRVec[2]) / jnp.sqrt(np.pi)
    tmp = jnp.array(alphaRVec[1])
    doubleFactorial = 1
    facCount = 1
    erfAlphaR = erf(alphaRVec[1])

    #bVec = jnp.empty((6, len(erfAlphaR)))
    bVec = jnp.empty(6)

    # bVec[1] = -erf(kappa*dr); bVec[2..5] are built recursively from it.
    # bVec[0] is never set and never read.
    bVec = bVec.at[1].set(-erfAlphaR)
    for i in range(2, 6):
        bVec = bVec.at[i].set((bVec[i-1]+(tmp*X/doubleFactorial)))
        facCount += 2
        doubleFactorial *= facCount
        tmp *= 2 * alphaRVec[2]

    ## C-Uind
    cud = 2.0*rInvVec[2]*(pscales*thole_c + bVec[2])
    if lmax >= 1:
        ##  D-Uind terms
        dud_m0 = -2.0*2.0/3.0*rInvVec[3]*(3.0*(pscales*thole_d0 + bVec[3]) + alphaRVec[3]*X)
        dud_m1 = 2.0*rInvVec[3]*(pscales*thole_d1 + bVec[3] - 2.0/3.0*alphaRVec[3]*X)
    else:
        dud_m0 = 0.0
        dud_m1 = 0.0

    if lmax >= 2:
        ## Uind-Q
        udq_m0 = 2.0*rInvVec[4]*(3.0*(pscales*thole_q0 + bVec[3]) + 4/3*alphaRVec[5]*X)
        udq_m1 =  -2.0*jnp.sqrt(3)*rInvVec[4]*(pscales*thole_q1 + bVec[3])
    else:
        udq_m0 = 0.0
        udq_m1 = 0.0
    ## Uind-Uind (scaled by dscales instead of pscales)
    udud_m0 = -2.0/3.0*rInvVec[3]*(3.0*(dscales*thole_d0 + bVec[3]) + alphaRVec[3]*X)
    udud_m1 = rInvVec[3]*(dscales*thole_d1 + bVec[3] - 2.0/3.0*alphaRVec[3]*X)
    return cud, dud_m0, dud_m1, udq_m0, udq_m1, udud_m0, udud_m1

calc_e_perm(dr, mscales, kappa, lmax=2)

This function calculates all the ePermCoefs at once. ePermCoefs is essentially the interaction tensor between permanent multipole components. Everything should be done in the so-called quasi-internal (qi) frame, where the energy is $E = \sum_{ij} \mathrm{qiQI}_i \, \mathrm{ePermCoeff}_{ij} \, \mathrm{qiQJ}_j$.

Inputs

dr: float: distance between one pair of particles mscales: float: scaling factor between permanent - permanent multipole interactions, for each pair kappa: float: \kappa in PME, unit in A^-1 lmax: int: max L

Output

cc, cd, dd_m0, dd_m1, cq, dq_m0, dq_m1, qq_m0, qq_m1, qq_m2: n * 1 array: ePermCoefs

Source code in dmff/admp/pme.py
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
@jit_condition(static_argnums=(3))
def calc_e_perm(dr, mscales, kappa, lmax=2):

    r'''
    This function calculates all the ePermCoefs of one pair at once.
    ePermCoefs is basically the interaction tensor between permanent multipole components.
    Everything should be done in the so-called quasi-internal (qi) frame:
    Energy = \sum_ij qiQI * ePermCoeff_ij * qiQJ

    Inputs:
        dr:
            float: distance between one pair of particles
        mscales:
            float: scaling factor between permanent - permanent multipole interactions, for each pair
        kappa:
            float: \kappa in PME, unit in A^-1
        lmax:
            int: max L (argument index 3, marked static for jit_condition)

    Output:
        cc, cd, dd_m0, dd_m1, cq, dq_m0, dq_m1, qq_m0, qq_m1, qq_m2:
            n * 1 array: ePermCoefs. Terms that need a higher lmax than requested
            (cd/dd_* for lmax < 1, cq/dq_*/qq_* for lmax < 2) are returned as 0.
    '''

    # be aware of unit and dimension !!
    # rInvVec[i] = DIELECTRIC / dr**i ; alphaRVec[i] = (kappa*dr)**i
    rInv = 1 / dr
    rInvVec = jnp.array([DIELECTRIC*(rInv**i) for i in range(0, 9)])
    alphaRVec = jnp.array([(kappa*dr)**i for i in range(0, 10)])
    # X = 2 exp(-(kappa*dr)^2) / sqrt(pi)
    X = 2 * jnp.exp(-alphaRVec[2]) / jnp.sqrt(np.pi)
    tmp = jnp.array(alphaRVec[1])
    doubleFactorial = 1
    facCount = 1
    erfAlphaR = erf(alphaRVec[1])

    # bVec[1] = -erf(kappa*dr); bVec[2..5] are built recursively from it.
    # bVec[0] is never set and never read.
    bVec = jnp.empty(6)

    bVec = bVec.at[1].set(-erfAlphaR)
    for i in range(2, 6):
        bVec = bVec.at[i].set((bVec[i-1]+(tmp*X/doubleFactorial)))
        facCount += 2
        doubleFactorial *= facCount
        tmp *= 2 * alphaRVec[2]

    # C-C: 1
    cc = rInvVec[1] * (mscales + bVec[2] - alphaRVec[1]*X)
    if lmax >= 1:
        # C-D
        cd = rInvVec[2] * (mscales + bVec[2])
        # D-D: 2
        dd_m0 = -2/3 * rInvVec[3] * (3*(mscales + bVec[3]) + alphaRVec[3]*X)
        dd_m1 = rInvVec[3] * (mscales + bVec[3] - (2/3)*alphaRVec[3]*X)
    else:
        cd = 0
        dd_m0 = 0
        dd_m1 = 0

    if lmax >= 2:
        ## C-Q: 1
        cq = (mscales + bVec[3]) * rInvVec[3]
        ## D-Q: 2
        dq_m0 = rInvVec[4] * (3* (mscales + bVec[3]) + (4/3) * alphaRVec[5]*X)
        dq_m1 = -jnp.sqrt(3) * rInvVec[4] * (mscales + bVec[3])
        ## Q-Q
        qq_m0 = rInvVec[5] * (6* (mscales + bVec[4]) + (4/45)* (-3 + 10*alphaRVec[2]) * alphaRVec[5]*X)
        qq_m1 = - (4/15) * rInvVec[5] * (15*(mscales+bVec[4]) + alphaRVec[5]*X)
        qq_m2 = rInvVec[5] * (mscales + bVec[4] - (4/15)*alphaRVec[5]*X)
    else:
        cq = 0
        dq_m0 = 0
        dq_m1 = 0
        qq_m0 = 0
        qq_m1 = 0
        qq_m2 = 0

    return cc, cd, dd_m0, dd_m1, cq, dq_m0, dq_m1, qq_m0, qq_m1, qq_m2

energy_pme(positions, box, pairs, Q_local, Uind_global, pol, tholes, mScales, pScales, dScales, construct_local_frame_fn, pme_recip_fn, kappa, K1, K2, K3, lmax, lpol, lpme=True)

This is the top-level wrapper for multipole PME

Input

positions: Na * 3: positions; box: 3 * 3: box; pairs: Np * 3: interacting pair indices and topology distance; Q_local: Na * (lmax+1)^2: harmonic multipoles of each site in the local frame; Uind_global: Na * 3: the induced dipole moments, in GLOBAL CARTESIAN coordinates; pol: (Na,) float: the polarizability of each site, in A**3; tholes: (Na,) float: the dimensionless Thole damping width of each atom (default 8 according to the MPID paper); mScales, pScales, dScales: (Nexcl,): exclusion scalings (1-2, 1-3, ...) for permanent–permanent, permanent–induced, and induced–induced interactions, respectively; construct_local_frame_fn: function: local frame constructor, from generate_construct_local_frames; pme_recip_fn: function: a reciprocal space calculator (see recip.py); kappa: float: kappa in A^-1; K1, K2, K3: int: max K for reciprocal calculations; lmax: int: maximum L; lpol: bool: whether the model is polarizable; lpme: bool: whether to do PME. If False, the reciprocal space is turned off and kappa is set to 0.

Output

energy: total pme energy

Source code in dmff/admp/pme.py
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
def energy_pme(positions, box, pairs,
        Q_local, Uind_global, pol, tholes,
        mScales, pScales, dScales,
        construct_local_frame_fn, pme_recip_fn, kappa, K1, K2, K3, lmax, lpol, lpme=True):
    '''
    This is the top-level wrapper for multipole PME.

    Input:
        positions:
            Na * 3: positions
        box:
            3 * 3: box
        pairs:
            Np * 3: interacting pair indices and topology distance
        Q_local:
            Na * (lmax+1)^2: harmonic multipoles of each site in local frame
        Uind_global:
            Na * 3: the induced dipole moment, in GLOBAL CARTESIAN! (only used when lpol)
        pol:
            (Na,) float: the polarizability of each site, unit in A**3
        tholes:
            (Na,) float: the thole damping widths for each atom, it's dimensionless, default is 8 according to MPID paper
        mScales, pScales, dScales:
            (Nexcl,): multipole-multipole interaction exclusion scalings: 1-2, 1-3 ...
            for permanent-permanent, permanent-induced, induced-induced interactions
        construct_local_frame_fn:
            function: local frame constructor (e.g. from generate_construct_local_frames);
            only called when lmax > 0
        pme_recip_fn:
            function: see recip.py, a reciprocal space calculator; only called when lpme is truthy
        kappa:
            float: kappa in A^-1
        K1, K2, K3:
            int: max K for reciprocal calculations (not used directly in this function)
        lmax:
            int: maximum L
        lpol:
            bool: polarizable or not
        lpme:
            bool: doing pme? If falsy, the reciprocal and self terms are skipped and kappa is set to 0

    Output:
        energy: total pme energy
    '''
    # if doing a multipolar calculation
    if lmax > 0:
        local_frames = construct_local_frame_fn(positions, box)
        Q_global = rot_local2global(Q_local, local_frames, lmax)
    else:
        if lpol:
            # if fixed multipole only contains charge, and it's polarizable, then expand Q matrix
            dips = jnp.zeros((Q_local.shape[0], 3))
            Q_global = jnp.hstack((Q_local, dips))
            lmax = 1
        else:
            Q_global = Q_local

    # note we assume when lpol is True, lmax should be >= 1
    if lpol:
        # convert Uind to global harmonics, in accord with Q_global
        U_ind = C1_c2h.dot(Uind_global.T).T
        Q_global_tot = Q_global.at[:, 1:4].add(U_ind)
    else:
        Q_global_tot = Q_global

    # Use the same truthiness test as the `if lpme:` below: whenever the reciprocal part
    # is skipped, the real-space part must be unscreened (kappa = 0).
    if not lpme:
        kappa = 0

    if lpol:
        ene_real = pme_real(positions, box, pairs, Q_global, U_ind, pol, tholes,
                           mScales, pScales, dScales, kappa, lmax, True)
    else:
        ene_real = pme_real(positions, box, pairs, Q_global, None, None, None,
                           mScales, None, None, kappa, lmax, False)

    if lpme:
        ene_recip = pme_recip_fn(positions, box, Q_global_tot)
        ene_self = pme_self(Q_global_tot, kappa, lmax)

        if lpol:
            ene_self += pol_penalty(U_ind, pol)
        return ene_real + ene_recip + ene_self

    else:
        if lpol:
            ene_self = pol_penalty(U_ind, pol)
        else:
            ene_self = 0.0
        return ene_real + ene_self

gen_trim_val_0(thresh)

Trim values that fall below the threshold (i.e. near zero) up to the threshold, to avoid singularities.

Source code in dmff/admp/pme.py
469
470
471
472
473
474
475
476
477
478
def gen_trim_val_0(thresh):
    '''
    Build a function that clamps values from below at ``thresh``.

    Elements with x < thresh are replaced by thresh; elements with x >= thresh are
    returned unchanged. This keeps quantities away from zero to avoid singularities.
    The returned function is wrapped with jit when DO_JIT is set.
    '''
    def trim_val_0(x):
        below = x < thresh
        at_or_above = x >= thresh
        return jnp.piecewise(
            x,
            [below, at_or_above],
            [lambda _: jnp.array(thresh), lambda v: v],
        )

    return jit(trim_val_0) if DO_JIT else trim_val_0

gen_trim_val_infty(thresh)

Trim values at or above the threshold (i.e. towards infinity) down to the threshold, to avoid divergence.

Source code in dmff/admp/pme.py
483
484
485
486
487
488
489
490
491
492
def gen_trim_val_infty(thresh):
    '''
    Build a function that clamps values from above at ``thresh``.

    Elements with x < thresh are returned unchanged; elements with x >= thresh are
    replaced by thresh. This keeps large quantities from diverging.
    The returned function is wrapped with jit when DO_JIT is set.
    '''
    def trim_val_infty(x):
        below = x < thresh
        at_or_above = x >= thresh
        return jnp.piecewise(
            x,
            [below, at_or_above],
            [lambda v: v, lambda _: jnp.array(thresh)],
        )

    return jit(trim_val_infty) if DO_JIT else trim_val_infty

pme_real(positions, box, pairs, Q_global, Uind_global, pol, tholes, mScales, pScales, dScales, kappa, lmax, lpol)

This is the real-space PME calculation function. NOTE: it only deals with permanent–permanent multipole interactions. It expands the pairwise parameters and then invokes pme_real_kernel. It seems pointless to jit it, because: (1) the heavy-lifting kernel function is already jitted and vmapped; (2) len(pairs) keeps changing throughout the simulation, so the function would just recompile every time.

Input

positions: Na * 3: positions; box: 3 * 3: box, axes arranged in rows; pairs: Np * 3: interacting pair indices and topology distance; Q_global: Na * (l+1)**2: harmonic multipoles of each atom, in the global frame; Uind_global: Na * 3: harmonic induced dipoles, in the global frame; pol: (Na,): polarizabilities; tholes: (Na,): Thole damping parameters; mScales: (Nexcl,): permanent multipole–multipole interaction exclusion scalings (1-2, 1-3, ...); pScales, dScales: (Nexcl,): the corresponding exclusion scalings for permanent–induced and induced–induced interactions; kappa: float: kappa in A^-1; lmax: int: maximum L; lpol: bool: whether to do a polarizable calculation.

Output

ene: pme realspace energy

Source code in dmff/admp/pme.py
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
def pme_real(positions, box, pairs, 
        Q_global, Uind_global, pol, tholes,
        mScales, pScales, dScales,
        kappa, lmax, lpol):
    '''
    This is the real space PME calculate function.

    It expands the per-atom and per-topology parameters onto the pair list,
    rotates the multipoles (and, when lpol is True, the induced dipoles) into
    the quasi-internal frame of each pair, and then invokes pme_real_kernel.
    Permanent multipole interactions are always included; when lpol is True the
    induced-dipole terms are included as well.

    It seems pointless to jit it:
    1. the heavy-lifting kernel function is already jitted and vmapped
    2. len(pairs) keeps changing throughout the simulation, the function would
       just recompile every time

    Input:
        positions:
            Na * 3: positions
        box:
            3 * 3: box, axes arranged in row
        pairs:
            Np * 3: interacting pair indices (columns 0, 1) and topological distance (column 2)
        Q_global:
            Na * (l+1)**2: harmonics multipoles of each atom, in global frame
        Uind_global:
            Na * 3: harmonic induced dipoles, in global frame (only used if lpol)
        pol:
            (Na,): polarizabilities (only used if lpol)
        tholes:
            (Na,): thole damping parameters (only used if lpol)
        mScales:
            (Nexcl,): permanent multipole-multipole interaction exclusion scalings: 1-2, 1-3 ...
        pScales:
            (Nexcl,): permanent-induced interaction exclusion scalings (only used if lpol)
        dScales:
            (Nexcl,): induced-induced interaction exclusion scalings (only used if lpol)
        kappa:
            float: kappa in A^-1
        lmax:
            int: maximum L
        lpol:
            Bool: whether do a polarizable calculation?

    Output:
        ene: pme realspace energy
    '''
    # Regularize the pair indices and get per-pair buffer scaling factors.
    # NOTE(review): presumably 0 for padding/buffer pairs and 1 otherwise -- see pair_buffer_scales.
    pairs = pairs.at[:, :2].set(regularize_pairs(pairs[:, :2]))
    buffer_scales = pair_buffer_scales(pairs[:, :2])
    idx_i = pairs[:, 0]
    idx_j = pairs[:, 1]
    # Column 2 is the topological distance (nbonds); nbonds - 1 indexes the
    # exclusion-scaling arrays. NOTE(review): if distribute_scalar indexes
    # directly, nbonds == 0 maps to index -1, i.e. the last entry -- confirm.
    indices = pairs[:, 2] - 1

    # deals with geometries: PBC-shifted pair vectors and the quasi-internal frame of each pair
    box_inv = jnp.linalg.inv(box)
    r1 = distribute_v3(positions, idx_i)
    r2 = distribute_v3(positions, idx_j)
    dr = v_pbc_shift(r1 - r2, box, box_inv)
    norm_dr = jnp.linalg.norm(dr, axis=-1)
    Ri = build_quasi_internal(r1, r2, dr, norm_dr)

    # permanent multipoles, rotated into the quasi-internal frame
    qiQI = rot_global2local(distribute_multipoles(Q_global, idx_i), Ri, lmax)
    qiQJ = rot_global2local(distribute_multipoles(Q_global, idx_j), Ri, lmax)
    mscales = distribute_scalar(mScales, indices) * buffer_scales

    if lpol:
        qiUindI = rot_ind_global2local(distribute_v3(Uind_global, idx_i), Ri)
        qiUindJ = rot_ind_global2local(distribute_v3(Uind_global, idx_j), Ri)
        thole1 = distribute_scalar(tholes, idx_i)
        thole2 = distribute_scalar(tholes, idx_j)
        dmp = get_pair_dmp(distribute_scalar(pol, idx_i), distribute_scalar(pol, idx_j))
        pscales = distribute_scalar(pScales, indices) * buffer_scales
        dscales = distribute_scalar(dScales, indices) * buffer_scales
    else:
        # pme_real_kernel does not touch any of these when lpol is False
        qiUindI = qiUindJ = None
        thole1 = thole2 = dmp = None
        pscales = dscales = None

    # everything should be pair-specific now.
    # NOTE(review): buffer_scales is applied both to the scalings above and to the
    # kernel output here; this is redundant for a 0/1 mask but kept to preserve results.
    ene = jnp.sum(
        pme_real_kernel(
            norm_dr,
            qiQI,
            qiQJ,
            qiUindI,
            qiUindJ,
            thole1,
            thole2,
            dmp,
            mscales,
            pscales,
            dscales,
            kappa,
            lmax,
            lpol
        ) * buffer_scales
    )

    return ene

pme_real_kernel(dr, qiQI, qiQJ, qiUindI, qiUindJ, thole1, thole2, dmp, mscales, pscales, dscales, kappa, lmax=2, lpol=False)

This is the heavy-lifting kernel function that computes the real-space multipolar PME energy, vectorized over interacting pairs.

Input

dr: float, the interatomic distance ((np,) array if vectorized); qiQI: [(lmax+1)^2] float array, the harmonic multipoles of site i in the quasi-internal frame; qiQJ: [(lmax+1)^2] float array, the harmonic multipoles of site j in the quasi-internal frame; qiUindI: (3,) float array, the harmonic induced dipole of site i in the QI frame; qiUindJ: (3,) float array, the harmonic induced dipole of site j in the QI frame; thole1: float, Thole damping coefficient of site i; thole2: float, Thole damping coefficient of site j; dmp: float, distance rescaling parameter used in Thole damping (computed from pol1 and pol2); mscales: float, scaling factor between interacting sites (permanent-permanent); pscales: float, scaling factor for permanent-induced interactions; dscales: float, scaling factor for induced-induced interactions; kappa: float, kappa in A^-1; lmax: int, maximum angular momentum; lpol: bool, whether to include polarization.

Output

energy: float, realspace interaction energy between the sites

Source code in dmff/admp/pme.py
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
@partial(vmap, in_axes=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, None, None, None), out_axes=0)
@jit_condition(static_argnums=(12, 13))
def pme_real_kernel(dr, qiQI, qiQJ, qiUindI, qiUindJ, thole1, thole2, dmp, mscales, pscales, dscales, kappa, lmax=2, lpol=False):
    '''
    This is the heavy-lifting kernel function to compute the realspace multipolar PME
    energy of one interacting pair. It is vectorized over interacting pairs via vmap:
    every argument except kappa, lmax and lpol is mapped along its leading axis, and
    the output has one energy per pair. lmax and lpol are marked static for
    jit_condition, so the Python branches below are resolved at trace time.

    Harmonic component layout (per site, quasi-internal frame), as used below:
        index 0: charge
        index 1: dipole m0;      indices 2, 3: dipole m1 components
        index 4: quadrupole m0;  indices 5, 6: quadrupole m1;  indices 7, 8: quadrupole m2
    Induced dipoles (qiUindI / qiUindJ): index 0 pairs with the m0 terms,
    indices 1 and 2 with the m1 terms.

    Input:
        dr: 
            float, the interatomic distance, (np) array if vectorized
        qiQI:
            [(lmax+1)^2] float array, the harmonic multipoles of site i in quasi-internal frame
        qiQJ:
            [(lmax+1)^2] float array, the harmonic multipoles of site j in quasi-internal frame
        qiUindI:
            (3,) float array, the harmonic induced dipole of site i in QI frame
        qiUindJ:
            (3,) float array, the harmonic induced dipole of site j in QI frame
        thole1:
            float: thole damping coeff of site i
        thole2:
            float: thole damping coeff of site j
        dmp:
            float: distance rescaling param used in thole damping; computed by
            get_pair_dmp(pol1, pol2) in pme_real (presumably (pol1 * pol2)**(1/6) -- confirm)
        mscales:
            float, scaling factor between interacting sites (permanent-permanent)
        pscales:
            float, scaling factor between perm-ind interaction
        dscales:
            float, scaling factor between ind-ind interaction
        kappa:
            float, kappa in unit A^-1
        lmax:
            int, maximum angular momentum; must be 0, 1 or 2
        lpol:
            bool, doing polarization? When False, qiUindI, qiUindJ, thole1, thole2,
            dmp, pscales and dscales are not used (they may be None).

    Output:
        energy: 
            float, realspace interaction energy between the sites

    Raises:
        ValueError: if lmax is not 0, 1 or 2.
    '''

    # Distance-dependent interaction coefficients, one per (order pair, m) channel.
    cc, cd, dd_m0, dd_m1, cq, dq_m0, dq_m1, qq_m0, qq_m1, qq_m2 = calc_e_perm(dr, mscales, kappa, lmax)
    if lpol:
        cud, dud_m0, dud_m1, udq_m0, udq_m1, udud_m0, udud_m1 = calc_e_ind(dr, thole1, thole2, dmp, pscales, dscales, kappa, lmax)

    # Vij<k>: coefficient built from site i's moments that multiplies qiQJ[k] in the
    # final contraction; Vji<k> is the mirror quantity built from site j's moments.
    Vij0 = cc*qiQI[0]
    Vji0 = cc*qiQJ[0]
    # C-Uind
    if lpol: 
        Vij0 -= cud * qiUindI[0]
        Vji0 += cud * qiUindJ[0]

    if lmax >= 1:
        # C-D 
        Vij0 = Vij0 - cd*qiQI[1]
        Vji1 = -cd*qiQJ[0]
        Vij1 = cd*qiQI[0]
        Vji0 = Vji0 + cd*qiQJ[1]
        # D-D m0 
        Vij1 += dd_m0 * qiQI[1]
        Vji1 += dd_m0 * qiQJ[1]    
        # D-D m1 
        Vij2 = dd_m1*qiQI[2]
        Vji2 = dd_m1*qiQJ[2]
        Vij3 = dd_m1*qiQI[3]
        Vji3 = dd_m1*qiQJ[3]
        # D-Uind
        if lpol:
            Vij1 += dud_m0 * qiUindI[0]
            Vji1 += dud_m0 * qiUindJ[0]
            Vij2 += dud_m1 * qiUindI[1]
            Vji2 += dud_m1 * qiUindJ[1]
            Vij3 += dud_m1 * qiUindI[2]
            Vji3 += dud_m1 * qiUindJ[2]

    if lmax >= 2:
        # C-Q
        Vij0 = Vij0 + cq*qiQI[4]
        Vji4 = cq*qiQJ[0]
        Vij4 = cq*qiQI[0]
        Vji0 = Vji0 + cq*qiQJ[4]
        # D-Q m0
        Vij1 += dq_m0*qiQI[4]
        Vji4 += dq_m0*qiQJ[1] 
        # Q-D m0
        Vij4 -= dq_m0*qiQI[1]
        Vji1 -= dq_m0*qiQJ[4]
        # D-Q m1
        Vij2 = Vij2 + dq_m1*qiQI[5]
        Vji5 = dq_m1*qiQJ[2]
        Vij3 += dq_m1*qiQI[6]
        Vji6 = dq_m1*qiQJ[3]
        Vij5 = -(dq_m1*qiQI[2])
        Vji2 += -(dq_m1*qiQJ[5])
        Vij6 = -(dq_m1*qiQI[3])
        Vji3 += -(dq_m1*qiQJ[6])
        # Q-Q m0
        Vij4 += qq_m0*qiQI[4]
        Vji4 += qq_m0*qiQJ[4] 
        # Q-Q m1
        Vij5 += qq_m1*qiQI[5]
        Vji5 += qq_m1*qiQJ[5]
        Vij6 += qq_m1*qiQI[6]
        Vji6 += qq_m1*qiQJ[6]
        # Q-Q m2
        Vij7  = qq_m2*qiQI[7]
        Vji7  = qq_m2*qiQJ[7]
        Vij8  = qq_m2*qiQI[8]
        Vji8  = qq_m2*qiQJ[8]
        # Q-Uind
        if lpol:
            Vji4 += udq_m0*qiUindJ[0]
            Vij4 -= udq_m0*qiUindI[0]
            Vji5 += udq_m1*qiUindJ[1]
            Vji6 += udq_m1*qiUindJ[2]
            Vij5 -= udq_m1*qiUindI[1]
            Vij6 -= udq_m1*qiUindI[2]

    # Uind - Uind: kept separate because it is contracted with the induced
    # dipoles (qiUindJ / qiUindI) rather than the permanent multipoles
    if lpol:
        Vij1dd = udud_m0 * qiUindI[0]
        Vji1dd = udud_m0 * qiUindJ[0]
        Vij2dd = udud_m1 * qiUindI[1]
        Vji2dd = udud_m1 * qiUindJ[1]
        Vij3dd = udud_m1 * qiUindI[2]
        Vji3dd = udud_m1 * qiUindJ[2]
        Vijdd = jnp.stack(( Vij1dd, Vij2dd, Vij3dd))
        Vjidd = jnp.stack(( Vji1dd, Vji2dd, Vji3dd))

    # assemble the coefficient vectors with the same length as qiQI / qiQJ
    if lmax == 0:
        Vij = Vij0
        Vji = Vji0
    elif lmax == 1:
        Vij = jnp.stack((Vij0, Vij1, Vij2, Vij3))
        Vji = jnp.stack((Vji0, Vji1, Vji2, Vji3))
    elif lmax == 2:
        Vij = jnp.stack((Vij0, Vij1, Vij2, Vij3, Vij4, Vij5, Vij6, Vij7, Vij8))
        Vji = jnp.stack((Vji0, Vji1, Vji2, Vji3, Vji4, Vji5, Vji6, Vji7, Vji8))
    else:
        raise ValueError(f"Invalid lmax {lmax}. Valid values are 0, 1, 2")

    # Each pair term is accumulated from both sides (Vij . qiQJ and Vji . qiQI),
    # so the two contractions are averaged with the factor 0.5.
    if lpol:
        return jnp.array(0.5) * (jnp.sum(qiQJ*Vij) + jnp.sum(qiQI*Vji)) + jnp.array(0.5) * (jnp.sum(qiUindJ*Vijdd) + jnp.sum(qiUindI*Vjidd))
    else:
        return jnp.array(0.5) * (jnp.sum(qiQJ*Vij) + jnp.sum(qiQI*Vji))

pme_self(Q_h, kappa, lmax=2)

This function calculates the PME self energy

Inputs

Q_h: Na * (lmax+1)^2: harmonic multipoles, local or global does not matter; kappa: float: kappa used in PME; lmax: int: maximum L (0, 1 or 2).

Output

ene_self: float: the self energy

Source code in dmff/admp/pme.py
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
@jit_condition(static_argnums=(2,))
def pme_self(Q_h, kappa, lmax=2):
    '''
    This function calculates the PME self energy.

    The self term is - DIELECTRIC * sum over atoms and harmonic components of
    kappa/sqrt(pi) * (2*kappa**2)**l / (2l+1)!! * Q_h**2, where l is the
    angular momentum of each harmonic component.

    Inputs:
        Q_h:
            Na * (lmax+1)^2: harmonic multipoles, local or global does not matter
        kappa:
            float: kappa used in PME
        lmax:
            int: maximum L; must be 0, 1 or 2 (static argument under jit)

    Output:
        ene_self:
            float: the self energy

    Raises:
        ValueError: if lmax is not 0, 1 or 2 (the factor tables below only cover l <= 2).
    '''
    if lmax not in (0, 1, 2):
        raise ValueError(f"Invalid lmax {lmax}. Valid values are 0, 1, 2")
    n_harms = (lmax + 1) ** 2
    # angular momentum l of each harmonic component (1 charge, 3 dipole, 5 quadrupole terms)
    l_list = np.array([0] + [1,]*3 + [2,]*5)[:n_harms]
    # (2l+1)!! for l = 0, 1, 2
    l_fac2 = np.array([1] + [3,]*3 + [15,]*5)[:n_harms]
    factor = kappa/np.sqrt(np.pi) * (2*kappa**2)**l_list / l_fac2
    return - jnp.sum(factor[np.newaxis] * Q_h**2) * DIELECTRIC

pol_penalty(U_ind, pol)

The energy penalty for polarization of each site, currently only supports isotropic polarization:

Inputs

U_ind: Na * 3 float: induced dipoles, in isotropic polarization case, cartesian or harmonic does not matter pol: (Na,) float: polarizability

Source code in dmff/admp/pme.py
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
@jit_condition(static_argnums=())
def pol_penalty(U_ind, pol):
    '''
    Energy penalty for polarizing each site; only isotropic polarization is supported.

    Computes DIELECTRIC * sum_i |U_ind[i]|^2 / (2 * pol_i), where pol is first
    passed through trim_val_0 to remove the singularity at pol = 0.

    Inputs:
        U_ind:
            Na * 3 float: induced dipoles; for isotropic polarization, cartesian or harmonic does not matter
        pol:
            (Na,) float: polarizability

    Output:
        float: the polarization penalty energy
    '''
    # guard the division against non-polarizable (pol = 0) sites
    safe_pol = trim_val_0(pol)
    # transpose to (3, Na) so the (Na,) polarizabilities broadcast along the atom axis
    squared_t = (U_ind**2).T
    half_inv_pol = 0.5 / safe_pol
    return jnp.sum(half_inv_pol * squared_t) * DIELECTRIC

setup_ewald_parameters(rc, ethresh, box=None, spacing=None, method='openmm')

Given the cutoff distance, and the required precision, determine the parameters used in Ewald sum, including: kappa, K1, K2, and K3.


rc: float — the cutoff distance, in nm.

ethresh: float — required energy precision, in kJ/mol.

box: ndarray, optional — 3*3 matrix, box size, with a, b, c arranged in rows. It is used by the openmm method, and the gromacs method also uses it to size the mesh.

spacing: float, optional — Fourier spacing used to determine K in the gromacs method.

method: str — method used to determine the Ewald parameters. Valid values are "openmm" and "gromacs". The openmm algorithm is described at http://docs.openmm.org/latest/userguide/theory.html. The gromacs algorithm is adapted from the GROMACS source code.

Returns

kappa, K1, K2, K3: (float, int, int, int) — kappa is the attenuation factor; K1, K2, K3 are integers giving the sizes of the k-point mesh.

Source code in dmff/admp/pme.py
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
def setup_ewald_parameters(
    rc: float,
    ethresh: float,
    box: Optional[jnp.ndarray] = None,
    spacing: Optional[float] = None,
    method: str = 'openmm'
) -> Tuple[float, int, int, int]:
    '''
    Given the cutoff distance, and the required precision, determine the parameters used in
    Ewald sum, including: kappa, K1, K2, and K3.


    Parameters:
    ----------
    rc: float
        The cutoff distance, in nm. Must be positive.
    ethresh: float
        Required energy precision, in kJ/mol. Must be positive; with the openmm
        method it must also be smaller than 0.5 (the formula uses log(2 * ethresh)).
    box: ndarray
        3*3 matrix, box size, a, b, c arranged in rows. Required by both methods;
        only its diagonal elements are used to size the mesh.
    spacing: float, optional
        fourier spacing to determine K; required by (and only used in) the gromacs method
    method: str
        Method to determine ewald parameters. Valid values: "openmm" or "gromacs".
        If openmm, the algorithm can refer to http://docs.openmm.org/latest/userguide/theory.html
        If gromacs, the algorithm is adapted from gromacs source code

    Returns
    -------
    kappa, K1, K2, K3: (float, int, int, int)
        kappa: the attenuation factor
        K1, K2, K3: integers, sizes of the k-points mesh

    Raises
    ------
    ValueError
        If method is unknown, rc or ethresh is out of range, or box / spacing is
        missing for the selected method.
    '''
    if method not in ("openmm", "gromacs"):
        raise ValueError(
            f"Invalid method: {method}. "
            "Valid methods: 'openmm', 'gromacs'"
        )
    # rc <= 0 gives erfc(kappa * rc) >= 1 for every kappa, so the gromacs doubling
    # loop below would never terminate for ethresh < 1.
    if rc <= 0:
        raise ValueError(f"rc must be positive, got {rc}")
    # A non-positive ethresh can likewise keep the doubling loop running forever
    # and makes log(2 * ethresh) undefined in the openmm formula.
    if ethresh <= 0:
        raise ValueError(f"ethresh must be positive, got {ethresh}")
    if box is None:
        raise ValueError(f"box is required by the '{method}' method")

    if method == "openmm":
        if ethresh >= 0.5:
            raise ValueError(
                f"ethresh must be smaller than 0.5 for the openmm method, got {ethresh}"
            )
        kappa = jnp.sqrt(-jnp.log(2 * ethresh)) / rc
        # mesh size along each box axis: ceil(2 * kappa * L / (3 * ethresh**(1/5)))
        K1, K2, K3 = (
            int(jnp.ceil(2 * kappa * box[axis, axis] / 3 / ethresh**0.2))
            for axis in range(3)
        )
        return kappa, K1, K2, K3

    # gromacs method
    if spacing is None or spacing <= 0:
        raise ValueError(f"a positive spacing is required by the 'gromacs' method, got {spacing}")
    # bracket kappa by doubling until erfc(kappa * rc) <= ethresh ...
    kappa = 5.0
    i = 0
    while erfc(kappa * rc) > ethresh:
        i += 1
        kappa *= 2
    # ... then bisect for the kappa at which erfc(kappa * rc) crosses ethresh
    n = i + 60
    low = 0.0
    high = kappa
    for _ in range(n):
        kappa = (low + high) / 2
        if erfc(kappa * rc) > ethresh:
            low = kappa
        else:
            high = kappa
    # determine K from the Fourier spacing along each box axis
    K1, K2, K3 = (int(jnp.ceil(box[axis, axis] / spacing)) for axis in range(3))
    return kappa, K1, K2, K3

switch_val(x, x0, sigma, y0, y1)

This is a Fermi function that switches between y0 and y1 according to the value of x: y = y0 when x << x0, and y = y1 when x >> x0. sigma controls the switch width.

Source code in dmff/admp/pme.py
455
456
457
458
459
460
461
462
463
464
465
466
@jit_condition(static_argnums=())
def switch_val(x, x0, sigma, y0, y1):
    '''
    Fermi-function switch between y0 and y1, driven by x.

    y -> y0 when x << x0
    y -> y1 when x >> x0
    sigma controls the switch width.
    '''
    weight_y0 = 1 / (jnp.exp((x - x0) / sigma) + 1)
    weight_y1 = 1 - weight_y0
    return weight_y0 * y0 + weight_y1 * y1