Skip to content

Commit

Permalink
Wrap NonbondedAllPairs, NonbondedPairList (#643)
Browse files Browse the repository at this point in the history
  • Loading branch information
mcwitt authored Feb 24, 2022
1 parent 85f0503 commit c400580
Show file tree
Hide file tree
Showing 8 changed files with 550 additions and 73 deletions.
71 changes: 71 additions & 0 deletions tests/nonbonded/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import numpy as np
import pytest
from simtk.openmm import app

from timemachine.fe.utils import to_md_units
from timemachine.ff.handlers import openmm_deserializer
from timemachine.lib import potentials

# NOTE: For efficiency, we use module-scoped fixtures for expensive
# setup. To prevent unintended mutation, these shouldn't be used in
# tests directly. Instead, they are wrapped below by function-scoped
# fixtures that return copies of the data.


@pytest.fixture(scope="module")
def _example_system():
pdb_path = "tests/data/5dfr_solv_equil.pdb"
host_pdb = app.PDBFile(pdb_path)
ff = app.ForceField("amber99sbildn.xml", "tip3p.xml")
return (
ff.createSystem(host_pdb.topology, nonbondedMethod=app.NoCutoff, constraints=None, rigidWater=False),
host_pdb.positions,
host_pdb.topology.getPeriodicBoxVectors(),
)


@pytest.fixture(scope="module")
def _example_nonbonded_params(_example_system):
host_system, _, _ = _example_system
host_fns, _ = openmm_deserializer.deserialize_system(host_system, cutoff=1.0)

nonbonded_fn = None
for f in host_fns:
if isinstance(f, potentials.Nonbonded):
nonbonded_fn = f

assert nonbonded_fn is not None
return nonbonded_fn.params


@pytest.fixture(scope="module")
def _example_conf(_example_system):
_, host_conf, _ = _example_system
return np.array([[to_md_units(x), to_md_units(y), to_md_units(z)] for x, y, z in host_conf])


@pytest.fixture(scope="function", autouse=True)
def set_random_seed():
np.random.seed(2022)
yield


@pytest.fixture(scope="function")
def rng():
return np.random.default_rng(2022)


@pytest.fixture(scope="function")
def example_nonbonded_params(_example_nonbonded_params):
return _example_nonbonded_params[:]


@pytest.fixture(scope="function")
def example_conf(_example_conf):
return _example_conf[:]


@pytest.fixture(scope="function")
def example_box(_example_system):
_, _, box = _example_system
return np.asarray(box / box.unit)
21 changes: 21 additions & 0 deletions tests/nonbonded/parameter_interpolation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import numpy as np
import numpy.typing as npt


def gen_params(params_initial: npt.NDArray, rng: np.random.Generator, dcharge=0.01, dlogsig=0.1, dlogeps=0.1):
"""Given an initial set of nonbonded parameters, generate random
final parameters and return the concatenation of the initial and
final parameters"""

num_atoms, _ = params_initial.shape
charge_init, sig_init, eps_init = params_initial[:].T

charge_final = charge_init + rng.normal(0, dcharge, size=(num_atoms,))

# perturb LJ parameters in log space to avoid negative result
sig_final = np.where(sig_init, np.exp(np.log(sig_init) + rng.normal(0, dlogsig, size=(num_atoms,))), 0)
eps_final = np.where(eps_init, np.exp(np.log(eps_init) + rng.normal(0, dlogeps, size=(num_atoms,))), 0)

params_final = np.stack((charge_final, sig_final, eps_final), axis=1)

return np.concatenate((params_initial, params_final))
128 changes: 128 additions & 0 deletions tests/nonbonded/test_nonbonded_all_pairs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import functools

import numpy as np
import pytest
from common import GradientTest
from parameter_interpolation import gen_params

from timemachine.lib.potentials import NonbondedAllPairs, NonbondedAllPairsInterpolated
from timemachine.potentials import nonbonded


def test_nonbonded_all_pairs_invalid_planes_offsets():
with pytest.raises(RuntimeError) as e:
NonbondedAllPairs([0], [0, 0], 2.0, 1.1).unbound_impl(np.float32)

assert "lambda offset idxs and plane idxs need to be equivalent" in str(e)


def test_nonbonded_all_pairs_invalid_num_atoms():
potential = NonbondedAllPairs([0], [0], 2.0, 1.1).unbound_impl(np.float32)
with pytest.raises(RuntimeError) as e:
potential.execute(np.zeros((2, 3)), np.zeros((1, 3)), np.eye(3), 0)

assert "NonbondedAllPairs::execute_device(): expected N == N_, got N=2, N_=1" in str(e)


def test_nonbonded_all_pairs_invalid_num_params():
potential = NonbondedAllPairs([0], [0], 2.0, 1.1).unbound_impl(np.float32)
with pytest.raises(RuntimeError) as e:
potential.execute(np.zeros((1, 3)), np.zeros((2, 3)), np.eye(3), 0)

assert "NonbondedAllPairs::execute_device(): expected P == M*N_*3, got P=6, M*N_*3=3" in str(e)

potential_interp = NonbondedAllPairsInterpolated([0], [0], 2.0, 1.1).unbound_impl(np.float32)
with pytest.raises(RuntimeError) as e:
potential_interp.execute(np.zeros((1, 3)), np.zeros((1, 3)), np.eye(3), 0)

assert "NonbondedAllPairs::execute_device(): expected P == M*N_*3, got P=3, M*N_*3=6" in str(e)


def make_ref_potential(lambda_plane_idxs, lambda_offset_idxs, beta, cutoff):
@functools.wraps(nonbonded.nonbonded_v3)
def wrapped(conf, params, box, lamb):
num_atoms, _ = conf.shape
no_rescale = np.ones((num_atoms, num_atoms))
return nonbonded.nonbonded_v3(
conf,
params,
box,
lamb,
charge_rescale_mask=no_rescale,
lj_rescale_mask=no_rescale,
beta=beta,
cutoff=cutoff,
lambda_plane_idxs=lambda_plane_idxs,
lambda_offset_idxs=lambda_offset_idxs,
)

return wrapped


@pytest.mark.parametrize("lamb", [0.0, 0.1])
@pytest.mark.parametrize("beta", [2.0])
@pytest.mark.parametrize("cutoff", [1.1])
@pytest.mark.parametrize("precision,rtol,atol", [(np.float64, 1e-8, 1e-8), (np.float32, 1e-4, 5e-4)])
@pytest.mark.parametrize("num_atoms", [33, 65, 231, 1050, 4080])
def test_nonbonded_all_pairs_correctness(
num_atoms,
precision,
rtol,
atol,
cutoff,
beta,
lamb,
example_nonbonded_params,
example_conf,
example_box,
rng: np.random.Generator,
):
"Compares with jax reference implementation."

conf = example_conf[:num_atoms]
params = example_nonbonded_params[:num_atoms, :]

lambda_plane_idxs = rng.integers(-2, 3, size=(num_atoms,), dtype=np.int32)
lambda_offset_idxs = rng.integers(-2, 3, size=(num_atoms,), dtype=np.int32)

ref_potential = make_ref_potential(lambda_plane_idxs, lambda_offset_idxs, beta, cutoff)
test_potential = NonbondedAllPairs(lambda_plane_idxs, lambda_offset_idxs, beta, cutoff)

GradientTest().compare_forces(
conf, params, example_box, lamb, ref_potential, test_potential, precision=precision, rtol=rtol, atol=atol
)


@pytest.mark.parametrize("lamb", [0.0, 0.1, 0.9, 1.0])
@pytest.mark.parametrize("beta", [2.0])
@pytest.mark.parametrize("cutoff", [1.1])
@pytest.mark.parametrize("precision,rtol,atol", [(np.float64, 1e-8, 1e-8), (np.float32, 2e-4, 5e-4)])
@pytest.mark.parametrize("num_atoms", [33, 231, 4080])
def test_nonbonded_all_pairs_interpolated_correctness(
num_atoms,
precision,
rtol,
atol,
cutoff,
beta,
lamb,
example_nonbonded_params,
example_conf,
example_box,
rng: np.random.Generator,
):
"Compares with jax reference implementation, with parameter interpolation."

conf = example_conf[:num_atoms]
params_initial = example_nonbonded_params[:num_atoms, :]
params = gen_params(params_initial, rng)

lambda_plane_idxs = rng.integers(-2, 3, size=(num_atoms,), dtype=np.int32)
lambda_offset_idxs = rng.integers(-2, 3, size=(num_atoms,), dtype=np.int32)

ref_potential = nonbonded.interpolated(make_ref_potential(lambda_plane_idxs, lambda_offset_idxs, beta, cutoff))
test_potential = NonbondedAllPairsInterpolated(lambda_plane_idxs, lambda_offset_idxs, beta, cutoff)

GradientTest().compare_forces(
conf, params, example_box, lamb, ref_potential, test_potential, precision=precision, rtol=rtol, atol=atol
)
Original file line number Diff line number Diff line change
Expand Up @@ -5,78 +5,11 @@
import numpy as np
import pytest
from common import GradientTest, prepare_reference_nonbonded
from simtk.openmm import app
from parameter_interpolation import gen_params

from timemachine.fe.utils import to_md_units
from timemachine.ff.handlers import openmm_deserializer
from timemachine.lib import potentials
from timemachine.lib.potentials import NonbondedInteractionGroup, NonbondedInteractionGroupInterpolated
from timemachine.potentials import jax_utils, nonbonded

# NOTE: For efficiency, we use module-scoped fixtures for expensive
# setup. To prevent unintended mutation, these shouldn't be used in
# tests directly. Instead, they are wrapped below by function-scoped
# fixtures that return copies of the data.


@pytest.fixture(scope="module")
def _example_system():
pdb_path = "tests/data/5dfr_solv_equil.pdb"
host_pdb = app.PDBFile(pdb_path)
ff = app.ForceField("amber99sbildn.xml", "tip3p.xml")
return (
ff.createSystem(host_pdb.topology, nonbondedMethod=app.NoCutoff, constraints=None, rigidWater=False),
host_pdb.positions,
host_pdb.topology.getPeriodicBoxVectors(),
)


@pytest.fixture(scope="module")
def _example_nonbonded_params(_example_system):
host_system, _, _ = _example_system
host_fns, _ = openmm_deserializer.deserialize_system(host_system, cutoff=1.0)

nonbonded_fn = None
for f in host_fns:
if isinstance(f, potentials.Nonbonded):
nonbonded_fn = f

assert nonbonded_fn is not None
return nonbonded_fn.params


@pytest.fixture(scope="module")
def _example_conf(_example_system):
_, host_conf, _ = _example_system
return np.array([[to_md_units(x), to_md_units(y), to_md_units(z)] for x, y, z in host_conf])


@pytest.fixture(scope="function", autouse=True)
def set_random_seed():
np.random.seed(2022)
yield


@pytest.fixture(scope="function")
def rng():
return np.random.default_rng(2022)


@pytest.fixture(scope="function")
def example_nonbonded_params(_example_nonbonded_params):
return _example_nonbonded_params[:]


@pytest.fixture(scope="function")
def example_conf(_example_conf):
return _example_conf[:]


@pytest.fixture(scope="function")
def example_box(_example_system):
_, _, box = _example_system
return np.asarray(box / box.unit)


def test_nonbonded_interaction_group_invalid_indices():
def make_potential(ligand_idxs, num_atoms):
Expand Down Expand Up @@ -214,8 +147,7 @@ def test_nonbonded_interaction_group_interpolated_correctness(

conf = example_conf[:num_atoms]
params_initial = example_nonbonded_params[:num_atoms, :]
params_final = params_initial + rng.normal(0, 0.01, size=params_initial.shape)
params = np.concatenate((params_initial, params_final))
params = gen_params(params_initial, rng)

lambda_plane_idxs = rng.integers(-2, 3, size=(num_atoms,), dtype=np.int32)
lambda_offset_idxs = rng.integers(-2, 3, size=(num_atoms,), dtype=np.int32)
Expand Down
Loading

0 comments on commit c400580

Please sign in to comment.