Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wrap NonbondedAllPairs, NonbondedPairList #643

Merged
merged 11 commits into from
Feb 24, 2022
71 changes: 71 additions & 0 deletions tests/nonbonded/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import numpy as np
import pytest
from simtk.openmm import app

from timemachine.fe.utils import to_md_units
from timemachine.ff.handlers import openmm_deserializer
from timemachine.lib import potentials

# NOTE: For efficiency, we use module-scoped fixtures for expensive
# setup. To prevent unintended mutation, these shouldn't be used in
# tests directly. Instead, they are wrapped below by function-scoped
# fixtures that return copies of the data.


@pytest.fixture(scope="module")
def _example_system():
pdb_path = "tests/data/5dfr_solv_equil.pdb"
host_pdb = app.PDBFile(pdb_path)
ff = app.ForceField("amber99sbildn.xml", "tip3p.xml")
return (
ff.createSystem(host_pdb.topology, nonbondedMethod=app.NoCutoff, constraints=None, rigidWater=False),
host_pdb.positions,
host_pdb.topology.getPeriodicBoxVectors(),
)


@pytest.fixture(scope="module")
def _example_nonbonded_params(_example_system):
host_system, _, _ = _example_system
host_fns, _ = openmm_deserializer.deserialize_system(host_system, cutoff=1.0)

nonbonded_fn = None
for f in host_fns:
if isinstance(f, potentials.Nonbonded):
nonbonded_fn = f

assert nonbonded_fn is not None
return nonbonded_fn.params


@pytest.fixture(scope="module")
def _example_conf(_example_system):
_, host_conf, _ = _example_system
return np.array([[to_md_units(x), to_md_units(y), to_md_units(z)] for x, y, z in host_conf])


@pytest.fixture(scope="function", autouse=True)
def set_random_seed():
np.random.seed(2022)
yield


@pytest.fixture(scope="function")
def rng():
return np.random.default_rng(2022)


@pytest.fixture(scope="function")
def example_nonbonded_params(_example_nonbonded_params):
return _example_nonbonded_params[:]


@pytest.fixture(scope="function")
def example_conf(_example_conf):
return _example_conf[:]


@pytest.fixture(scope="function")
def example_box(_example_system):
_, _, box = _example_system
return np.asarray(box / box.unit)
128 changes: 128 additions & 0 deletions tests/nonbonded/test_nonbonded_all_pairs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import functools

import numpy as np
import pytest
from common import GradientTest

from timemachine.lib.potentials import NonbondedAllPairs, NonbondedAllPairsInterpolated
from timemachine.potentials import nonbonded


def test_nonbonded_all_pairs_invalid_planes_offsets():
with pytest.raises(RuntimeError) as e:
NonbondedAllPairs([0], [0, 0], 2.0, 1.1).unbound_impl(np.float32)

assert "lambda offset idxs and plane idxs need to be equivalent" in str(e)


def test_nonbonded_all_pairs_invalid_num_atoms():
potential = NonbondedAllPairs([0], [0], 2.0, 1.1).unbound_impl(np.float32)
with pytest.raises(RuntimeError) as e:
potential.execute(np.zeros((2, 3)), np.zeros((1, 3)), np.eye(3), 0)

assert "NonbondedAllPairs::execute_device(): expected N == N_, got N=2, N_=1" in str(e)


def test_nonbonded_all_pairs_invalid_num_params():
potential = NonbondedAllPairs([0], [0], 2.0, 1.1).unbound_impl(np.float32)
with pytest.raises(RuntimeError) as e:
potential.execute(np.zeros((1, 3)), np.zeros((2, 3)), np.eye(3), 0)

assert "NonbondedAllPairs::execute_device(): expected P == M*N_*3, got P=6, M*N_*3=3" in str(e)

potential_interp = NonbondedAllPairsInterpolated([0], [0], 2.0, 1.1).unbound_impl(np.float32)
with pytest.raises(RuntimeError) as e:
potential_interp.execute(np.zeros((1, 3)), np.zeros((1, 3)), np.eye(3), 0)

assert "NonbondedAllPairs::execute_device(): expected P == M*N_*3, got P=3, M*N_*3=6" in str(e)


def make_ref_potential(lambda_plane_idxs, lambda_offset_idxs, beta, cutoff):
@functools.wraps(nonbonded.nonbonded_v3)
def wrapped(conf, params, box, lamb):
num_atoms, _ = conf.shape
no_rescale = np.ones((num_atoms, num_atoms))
return nonbonded.nonbonded_v3(
conf,
params,
box,
lamb,
charge_rescale_mask=no_rescale,
lj_rescale_mask=no_rescale,
beta=beta,
cutoff=cutoff,
lambda_plane_idxs=lambda_plane_idxs,
lambda_offset_idxs=lambda_offset_idxs,
)

return wrapped


@pytest.mark.parametrize("lamb", [0.0, 0.1])
@pytest.mark.parametrize("beta", [2.0])
@pytest.mark.parametrize("cutoff", [1.1])
@pytest.mark.parametrize("precision,rtol,atol", [(np.float64, 1e-8, 1e-8), (np.float32, 1e-4, 5e-4)])
@pytest.mark.parametrize("num_atoms", [33, 65, 231, 1050, 4080])
def test_nonbonded_all_pairs_correctness(
num_atoms,
precision,
rtol,
atol,
cutoff,
beta,
lamb,
example_nonbonded_params,
example_conf,
example_box,
rng: np.random.Generator,
):
"Compares with jax reference implementation."

conf = example_conf[:num_atoms]
params = example_nonbonded_params[:num_atoms, :]

lambda_plane_idxs = rng.integers(-2, 3, size=(num_atoms,), dtype=np.int32)
lambda_offset_idxs = rng.integers(-2, 3, size=(num_atoms,), dtype=np.int32)

ref_potential = make_ref_potential(lambda_plane_idxs, lambda_offset_idxs, beta, cutoff)
test_potential = NonbondedAllPairs(lambda_plane_idxs, lambda_offset_idxs, beta, cutoff)

GradientTest().compare_forces(
conf, params, example_box, lamb, ref_potential, test_potential, precision=precision, rtol=rtol, atol=atol
)


@pytest.mark.parametrize("lamb", [0.0, 0.1, 0.9, 1.0])
@pytest.mark.parametrize("beta", [2.0])
@pytest.mark.parametrize("cutoff", [1.1])
@pytest.mark.parametrize("precision,rtol,atol", [(np.float64, 1e-8, 1e-8), (np.float32, 2e-4, 5e-4)])
@pytest.mark.parametrize("num_atoms", [33, 231, 4080])
def test_nonbonded_all_pairs_interpolated_correctness(
num_atoms,
precision,
rtol,
atol,
cutoff,
beta,
lamb,
example_nonbonded_params,
example_conf,
example_box,
rng: np.random.Generator,
):
"Compares with jax reference implementation, with parameter interpolation."

conf = example_conf[:num_atoms]
params_initial = example_nonbonded_params[:num_atoms, :]
params_final = params_initial + np.where(params_initial, rng.normal(0, 0.01, size=params_initial.shape), 0)
maxentile marked this conversation as resolved.
Show resolved Hide resolved
params = np.concatenate((params_initial, params_final))

lambda_plane_idxs = rng.integers(-2, 3, size=(num_atoms,), dtype=np.int32)
lambda_offset_idxs = rng.integers(-2, 3, size=(num_atoms,), dtype=np.int32)

ref_potential = nonbonded.interpolated(make_ref_potential(lambda_plane_idxs, lambda_offset_idxs, beta, cutoff))
test_potential = NonbondedAllPairsInterpolated(lambda_plane_idxs, lambda_offset_idxs, beta, cutoff)

GradientTest().compare_forces(
conf, params, example_box, lamb, ref_potential, test_potential, precision=precision, rtol=rtol, atol=atol
)
Original file line number Diff line number Diff line change
Expand Up @@ -5,78 +5,10 @@
import numpy as np
import pytest
from common import GradientTest, prepare_reference_nonbonded
from simtk.openmm import app

from timemachine.fe.utils import to_md_units
from timemachine.ff.handlers import openmm_deserializer
from timemachine.lib import potentials
from timemachine.lib.potentials import NonbondedInteractionGroup, NonbondedInteractionGroupInterpolated
from timemachine.potentials import jax_utils, nonbonded

# NOTE: For efficiency, we use module-scoped fixtures for expensive
# setup. To prevent unintended mutation, these shouldn't be used in
# tests directly. Instead, they are wrapped below by function-scoped
# fixtures that return copies of the data.


@pytest.fixture(scope="module")
def _example_system():
pdb_path = "tests/data/5dfr_solv_equil.pdb"
host_pdb = app.PDBFile(pdb_path)
ff = app.ForceField("amber99sbildn.xml", "tip3p.xml")
return (
ff.createSystem(host_pdb.topology, nonbondedMethod=app.NoCutoff, constraints=None, rigidWater=False),
host_pdb.positions,
host_pdb.topology.getPeriodicBoxVectors(),
)


@pytest.fixture(scope="module")
def _example_nonbonded_params(_example_system):
host_system, _, _ = _example_system
host_fns, _ = openmm_deserializer.deserialize_system(host_system, cutoff=1.0)

nonbonded_fn = None
for f in host_fns:
if isinstance(f, potentials.Nonbonded):
nonbonded_fn = f

assert nonbonded_fn is not None
return nonbonded_fn.params


@pytest.fixture(scope="module")
def _example_conf(_example_system):
_, host_conf, _ = _example_system
return np.array([[to_md_units(x), to_md_units(y), to_md_units(z)] for x, y, z in host_conf])


@pytest.fixture(scope="function", autouse=True)
def set_random_seed():
np.random.seed(2022)
yield


@pytest.fixture(scope="function")
def rng():
return np.random.default_rng(2022)


@pytest.fixture(scope="function")
def example_nonbonded_params(_example_nonbonded_params):
return _example_nonbonded_params[:]


@pytest.fixture(scope="function")
def example_conf(_example_conf):
return _example_conf[:]


@pytest.fixture(scope="function")
def example_box(_example_system):
_, _, box = _example_system
return np.asarray(box / box.unit)


def test_nonbonded_interaction_group_invalid_indices():
def make_potential(ligand_idxs, num_atoms):
Expand Down
Loading