Skip to content

Commit

Permalink
Expose interpolated potential to Python API, add test
Browse files Browse the repository at this point in the history
  • Loading branch information
mcwitt committed Feb 18, 2022
1 parent dbf4fbe commit 6d199ca
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 11 deletions.
69 changes: 68 additions & 1 deletion tests/test_nonbonded_interaction_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from timemachine.fe.utils import to_md_units
from timemachine.ff.handlers import openmm_deserializer
from timemachine.lib import potentials
from timemachine.lib.potentials import NonbondedInteractionGroup
from timemachine.lib.potentials import NonbondedInteractionGroup, NonbondedInteractionGroupInterpolated
from timemachine.potentials import jax_utils, nonbonded


Expand Down Expand Up @@ -175,6 +175,73 @@ def ref_ixngroups(conf, params, box, lamb):
)


@pytest.mark.parametrize("lamb", [0.0, 0.1, 0.9, 1.0])
@pytest.mark.parametrize("beta", [2.0])
@pytest.mark.parametrize("cutoff", [1.1])
@pytest.mark.parametrize("precision,rtol,atol", [(np.float64, 1e-8, 1e-8), (np.float32, 1e-4, 5e-4)])
@pytest.mark.parametrize("num_atoms_ligand", [1, 15])
@pytest.mark.parametrize("num_atoms", [33])
def test_nonbonded_interaction_group_interpolated_correctness(
num_atoms,
num_atoms_ligand,
precision,
rtol,
atol,
cutoff,
beta,
lamb,
example_nonbonded_params,
example_conf,
example_box,
rng,
):
"Compares with jax reference implementation, with parameter interpolation."

conf = example_conf[:num_atoms]
params_initial = example_nonbonded_params[:num_atoms, :]
params_final = params_initial + rng.normal(0, 0.01, size=params_initial.shape)
params = np.concatenate((params_initial, params_final))

lambda_plane_idxs = rng.integers(-2, 3, size=num_atoms, dtype=np.int32)
lambda_offset_idxs = rng.integers(-2, 3, size=num_atoms, dtype=np.int32)

ligand_idxs = rng.choice(num_atoms, size=num_atoms_ligand, replace=False).astype(np.int32)
host_idxs = np.setdiff1d(np.arange(num_atoms), ligand_idxs)

@nonbonded.interpolated
def ref_ixngroups(conf, params, box, lamb):

# compute 4d coordinates
w = jax_utils.compute_lifting_parameter(lamb, lambda_plane_idxs, lambda_offset_idxs, cutoff)
conf_4d = jax_utils.augment_dim(conf, w)
box_4d = (1000 * jax.numpy.eye(4)).at[:3, :3].set(box)

vdW, electrostatics, _ = nonbonded.nonbonded_v3_interaction_groups(
conf_4d, params, box_4d, ligand_idxs, host_idxs, beta, cutoff
)
return jax.numpy.sum(vdW + electrostatics)

test_ixngroups = NonbondedInteractionGroupInterpolated(
ligand_idxs,
lambda_plane_idxs,
lambda_offset_idxs,
beta,
cutoff,
)

GradientTest().compare_forces(
conf,
params,
example_box,
lamb=lamb,
ref_potential=ref_ixngroups,
test_potential=test_ixngroups,
rtol=rtol,
atol=atol,
precision=precision,
)


@pytest.mark.parametrize("lamb", [0.0, 0.1])
@pytest.mark.parametrize("beta", [2.0])
@pytest.mark.parametrize("cutoff", [1.1])
Expand Down
12 changes: 2 additions & 10 deletions tests/test_parameter_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,7 @@
from common import GradientTest, prepare_water_system

from timemachine.lib import potentials


def interpolated_potential(conf, params, box, lamb, u_fn):
assert params.size % 2 == 0

CP = params.shape[0] // 2
new_params = (1 - lamb) * params[:CP] + lamb * params[CP:]

return u_fn(conf, new_params, box, lamb)
from timemachine.potentials import nonbonded


class TestInterpolatedPotential(GradientTest):
Expand Down Expand Up @@ -57,7 +49,7 @@ def test_nonbonded(self):

print("lambda", lamb, "cutoff", cutoff, "precision", precision, "xshape", coords.shape)

ref_interpolated_potential = functools.partial(interpolated_potential, u_fn=ref_potential)
ref_interpolated_potential = nonbonded.interpolated(ref_potential)

test_interpolated_potential = potentials.NonbondedInterpolated(*test_potential.args)

Expand Down
13 changes: 13 additions & 0 deletions timemachine/lib/potentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,3 +282,16 @@ def unbound_impl(self, precision):

class NonbondedInteractionGroup(CustomOpWrapper):
pass


class NonbondedInteractionGroupInterpolated(NonbondedInteractionGroup):
def unbound_impl(self, precision):
cls_name_base = "NonbondedInteractionGroup"
if precision == np.float64:
cls_name_base += "_f64_interpolated"
else:
cls_name_base += "_f32_interpolated"

custom_ctor = getattr(custom_ops, cls_name_base)

return custom_ctor(*self.args)
18 changes: 18 additions & 0 deletions timemachine/potentials/nonbonded.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import functools

import jax.numpy as np
from jax.ops import index, index_update
from jax.scipy.special import erfc
Expand Down Expand Up @@ -262,6 +264,22 @@ def nonbonded_v3_interaction_groups(conf, params, box, inds_l, inds_r, beta: flo
return vdW, electrostatics, pairs


def interpolated(u_fn):
@functools.wraps(u_fn)
def wrapper(conf, params, box, lamb):

# params is expected to be the concatenation of initial
# (lambda = 0) and final (lamda = 1) parameters, each of
# length num_atoms
assert params.size % 2 == 0
num_atoms = params.shape[0] // 2

new_params = (1 - lamb) * params[:num_atoms] + lamb * params[num_atoms:]
return u_fn(conf, new_params, box, lamb)

return wrapper


def validate_coulomb_cutoff(cutoff=1.0, beta=2.0, threshold=1e-2):
"""check whether f(r) = erfc(beta * r) <= threshold at r = cutoff
following https://github.com/proteneer/timemachine/pull/424#discussion_r629678467"""
Expand Down

0 comments on commit 6d199ca

Please sign in to comment.