From 6d199ca4e21366d4c84d0ea81f87f668368efc94 Mon Sep 17 00:00:00 2001 From: Matt Wittmann Date: Fri, 18 Feb 2022 11:39:52 -0800 Subject: [PATCH] Expose interpolated potential to Python API, add test --- tests/test_nonbonded_interaction_group.py | 69 ++++++++++++++++++++++- tests/test_parameter_interpolation.py | 12 +--- timemachine/lib/potentials.py | 13 +++++ timemachine/potentials/nonbonded.py | 18 ++++++ 4 files changed, 101 insertions(+), 11 deletions(-) diff --git a/tests/test_nonbonded_interaction_group.py b/tests/test_nonbonded_interaction_group.py index 0a6d5d6be..18574c1b6 100644 --- a/tests/test_nonbonded_interaction_group.py +++ b/tests/test_nonbonded_interaction_group.py @@ -10,7 +10,7 @@ from timemachine.fe.utils import to_md_units from timemachine.ff.handlers import openmm_deserializer from timemachine.lib import potentials -from timemachine.lib.potentials import NonbondedInteractionGroup +from timemachine.lib.potentials import NonbondedInteractionGroup, NonbondedInteractionGroupInterpolated from timemachine.potentials import jax_utils, nonbonded @@ -175,6 +175,73 @@ def ref_ixngroups(conf, params, box, lamb): ) +@pytest.mark.parametrize("lamb", [0.0, 0.1, 0.9, 1.0]) +@pytest.mark.parametrize("beta", [2.0]) +@pytest.mark.parametrize("cutoff", [1.1]) +@pytest.mark.parametrize("precision,rtol,atol", [(np.float64, 1e-8, 1e-8), (np.float32, 1e-4, 5e-4)]) +@pytest.mark.parametrize("num_atoms_ligand", [1, 15]) +@pytest.mark.parametrize("num_atoms", [33]) +def test_nonbonded_interaction_group_interpolated_correctness( + num_atoms, + num_atoms_ligand, + precision, + rtol, + atol, + cutoff, + beta, + lamb, + example_nonbonded_params, + example_conf, + example_box, + rng, +): + "Compares with jax reference implementation, with parameter interpolation." + + conf = example_conf[:num_atoms] + params_initial = example_nonbonded_params[:num_atoms, :] + params_final = params_initial + rng.normal(0, 0.01, size=params_initial.shape) + params = np.concatenate((params_initial, params_final)) + + lambda_plane_idxs = rng.integers(-2, 3, size=num_atoms, dtype=np.int32) + lambda_offset_idxs = rng.integers(-2, 3, size=num_atoms, dtype=np.int32) + + ligand_idxs = rng.choice(num_atoms, size=num_atoms_ligand, replace=False).astype(np.int32) + host_idxs = np.setdiff1d(np.arange(num_atoms), ligand_idxs) + + @nonbonded.interpolated + def ref_ixngroups(conf, params, box, lamb): + + # compute 4d coordinates + w = jax_utils.compute_lifting_parameter(lamb, lambda_plane_idxs, lambda_offset_idxs, cutoff) + conf_4d = jax_utils.augment_dim(conf, w) + box_4d = (1000 * jax.numpy.eye(4)).at[:3, :3].set(box) + + vdW, electrostatics, _ = nonbonded.nonbonded_v3_interaction_groups( + conf_4d, params, box_4d, ligand_idxs, host_idxs, beta, cutoff + ) + return jax.numpy.sum(vdW + electrostatics) + + test_ixngroups = NonbondedInteractionGroupInterpolated( + ligand_idxs, + lambda_plane_idxs, + lambda_offset_idxs, + beta, + cutoff, + ) + + GradientTest().compare_forces( + conf, + params, + example_box, + lamb=lamb, + ref_potential=ref_ixngroups, + test_potential=test_ixngroups, + rtol=rtol, + atol=atol, + precision=precision, + ) + + @pytest.mark.parametrize("lamb", [0.0, 0.1]) @pytest.mark.parametrize("beta", [2.0]) @pytest.mark.parametrize("cutoff", [1.1]) diff --git a/tests/test_parameter_interpolation.py b/tests/test_parameter_interpolation.py index 099a5fd5d..9299b05da 100644 --- a/tests/test_parameter_interpolation.py +++ b/tests/test_parameter_interpolation.py @@ -9,15 +9,7 @@ from common import GradientTest, prepare_water_system from timemachine.lib import potentials - - -def interpolated_potential(conf, params, box, lamb, u_fn): - assert params.size % 2 == 0 - - CP = params.shape[0] // 2 - new_params = (1 - lamb) * params[:CP] + lamb * params[CP:] - - return u_fn(conf, new_params, box, lamb) +from timemachine.potentials import nonbonded class TestInterpolatedPotential(GradientTest): @@ -57,7 +49,7 @@ def test_nonbonded(self): print("lambda", lamb, "cutoff", cutoff, "precision", precision, "xshape", coords.shape) - ref_interpolated_potential = functools.partial(interpolated_potential, u_fn=ref_potential) + ref_interpolated_potential = nonbonded.interpolated(ref_potential) test_interpolated_potential = potentials.NonbondedInterpolated(*test_potential.args) diff --git a/timemachine/lib/potentials.py b/timemachine/lib/potentials.py index 23720f3fb..26346a5d8 100644 --- a/timemachine/lib/potentials.py +++ b/timemachine/lib/potentials.py @@ -282,3 +282,16 @@ def unbound_impl(self, precision): class NonbondedInteractionGroup(CustomOpWrapper): pass + + +class NonbondedInteractionGroupInterpolated(NonbondedInteractionGroup): + def unbound_impl(self, precision): + cls_name_base = "NonbondedInteractionGroup" + if precision == np.float64: + cls_name_base += "_f64_interpolated" + else: + cls_name_base += "_f32_interpolated" + + custom_ctor = getattr(custom_ops, cls_name_base) + + return custom_ctor(*self.args) diff --git a/timemachine/potentials/nonbonded.py b/timemachine/potentials/nonbonded.py index 006351a52..14de0136d 100644 --- a/timemachine/potentials/nonbonded.py +++ b/timemachine/potentials/nonbonded.py @@ -1,3 +1,5 @@ +import functools + import jax.numpy as np from jax.ops import index, index_update from jax.scipy.special import erfc @@ -262,6 +264,22 @@ def nonbonded_v3_interaction_groups(conf, params, box, inds_l, inds_r, beta: flo return vdW, electrostatics, pairs +def interpolated(u_fn): + @functools.wraps(u_fn) + def wrapper(conf, params, box, lamb): + + # params is expected to be the concatenation of initial + # (lambda = 0) and final (lamda = 1) parameters, each of + # length num_atoms + assert params.size % 2 == 0 + num_atoms = params.shape[0] // 2 + + new_params = (1 - lamb) * params[:num_atoms] + lamb * params[num_atoms:] + return u_fn(conf, new_params, box, lamb) + + return wrapper + + def validate_coulomb_cutoff(cutoff=1.0, beta=2.0, threshold=1e-2): """check whether f(r) = erfc(beta * r) <= threshold at r = cutoff following https://github.com/proteneer/timemachine/pull/424#discussion_r629678467"""