mk_softcore_fn(sc_alpha, sc_sigma, sc_power=1, sc_r_power=6, if_state_A=True)

Make a soft-core distance function.

Source code in dmff/classical/fep.py (lines 16–37):
def mk_softcore_fn(
    sc_alpha: float,
    sc_sigma: float,
    sc_power: int = 1,
    sc_r_power: int = 6,
    if_state_A: bool = True,
) -> Callable:
    """
    Make softcore function.

    The returned function maps distances ``r`` to soft-core (shifted)
    distances::

        r_sc = (sc_alpha * sc_sigma**sc_r_power * lmd**sc_power
                + r**sc_r_power) ** (1 / sc_r_power)

    where ``lmd = fep_lambda`` when ``if_state_A`` is true and
    ``lmd = 1 - fep_lambda`` otherwise. When ``lmd == 0`` the expression
    reduces to ``(r**sc_r_power) ** (1 / sc_r_power)``, i.e. the unshifted
    (non-negative) distance.

    Args:
        sc_alpha: Soft-core ``alpha`` prefactor.
        sc_sigma: Soft-core ``sigma``; presumably in the same length unit as
            the distances passed to the returned function (confirm with
            callers).
        sc_power: Exponent applied to ``lmd``; only 1 or 2 are accepted.
        sc_r_power: Exponent applied to the distances and to ``sc_sigma``;
            only 6 is currently accepted.
        if_state_A: If true, the returned function uses ``fep_lambda``
            directly; otherwise it uses ``1 - fep_lambda``.

    Returns:
        ``softcore_fn(distances, fep_lambda)``, which evaluates the formula
        above elementwise with ``jax.numpy`` and returns the shifted
        distances.

    Raises:
        AssertionError: If ``sc_r_power`` or ``sc_power`` is unsupported.
    """
    # These checks stay ``assert`` statements so callers that expect
    # AssertionError keep working; note they are skipped under ``python -O``.
    assert sc_r_power == 6, f"sc_r_power must be 6, got {sc_r_power!r}"
    assert sc_power in (1, 2), f"sc_power must be 1 or 2, got {sc_power!r}"

    # Independent of lambda and distances, so computed once for the closure.
    sig_pow = jnp.power(sc_sigma, sc_r_power)
    # Derive the root from sc_r_power rather than hard-coding 1/6, so the
    # root always matches the power applied to distances and sigma.
    inv_r_power = 1.0 / sc_r_power

    def softcore_fn(distances, fep_lambda: float):
        """Return soft-core distances for ``distances`` at ``fep_lambda``."""
        dist_pow = jnp.power(distances, sc_r_power)
        lmd = fep_lambda if if_state_A else 1 - fep_lambda
        lmd_pow = jnp.power(lmd, sc_power)
        return jnp.power(sc_alpha * sig_pow * lmd_pow + dist_pow, inv_r_power)

    return softcore_fn