Skip to content

Commit

Permalink
Wrap NonbondedAllPairs and NonbondedPairList, add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mcwitt committed Feb 22, 2022
1 parent be7900e commit e97aefc
Show file tree
Hide file tree
Showing 7 changed files with 448 additions and 56 deletions.
56 changes: 56 additions & 0 deletions tests/nonbonded/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import numpy as np
import pytest
from simtk.openmm import app

from timemachine.fe.utils import to_md_units
from timemachine.ff.handlers import openmm_deserializer
from timemachine.lib import potentials


@pytest.fixture(autouse=True)
def set_random_seed():
np.random.seed(2022)
yield


@pytest.fixture()
def rng():
return np.random.default_rng(2022)


@pytest.fixture
def example_system():
pdb_path = "tests/data/5dfr_solv_equil.pdb"
host_pdb = app.PDBFile(pdb_path)
ff = app.ForceField("amber99sbildn.xml", "tip3p.xml")
return (
ff.createSystem(host_pdb.topology, nonbondedMethod=app.NoCutoff, constraints=None, rigidWater=False),
host_pdb.positions,
host_pdb.topology.getPeriodicBoxVectors(),
)


@pytest.fixture
def example_nonbonded_params(example_system):
host_system, _, _ = example_system
host_fns, _ = openmm_deserializer.deserialize_system(host_system, cutoff=1.0)

nonbonded_fn = None
for f in host_fns:
if isinstance(f, potentials.Nonbonded):
nonbonded_fn = f

assert nonbonded_fn is not None
return nonbonded_fn.params


@pytest.fixture
def example_conf(example_system):
_, host_conf, _ = example_system
return np.array([[to_md_units(x), to_md_units(y), to_md_units(z)] for x, y, z in host_conf])


@pytest.fixture
def example_box(example_system):
_, _, box = example_system
return np.asarray(box / box.unit)
126 changes: 126 additions & 0 deletions tests/nonbonded/test_nonbonded_all_pairs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import functools

import numpy as np
import pytest
from common import GradientTest

from timemachine.lib.potentials import NonbondedAllPairs, NonbondedAllPairsInterpolated
from timemachine.potentials import nonbonded


def test_nonbonded_all_pairs_invalid_planes_offsets():
with pytest.raises(RuntimeError) as e:
NonbondedAllPairs([0], [0, 0], 2.0, 1.1).unbound_impl(np.float32)

assert "lambda offset idxs and plane idxs need to be equivalent" in str(e)


def test_nonbonded_all_pairs_invalid_num_atoms():
potential = NonbondedAllPairs([0], [0], 2.0, 1.1).unbound_impl(np.float32)
with pytest.raises(RuntimeError) as e:
potential.execute(np.zeros((2, 3)), np.zeros((1, 3)), np.eye(3), 0)

assert "NonbondedAllPairs::execute_device(): expected N == N_, got N=2, N_=1" in str(e)


def test_nonbonded_all_pairs_invalid_num_params():
potential = NonbondedAllPairs([0], [0], 2.0, 1.1).unbound_impl(np.float32)
with pytest.raises(RuntimeError) as e:
potential.execute(np.zeros((1, 3)), np.zeros((2, 3)), np.eye(3), 0)

assert "NonbondedAllPairs::execute_device(): expected P == M*N_*3, got P=6, M*N_*3=3" in str(e)

potential_interp = NonbondedAllPairsInterpolated([0], [0], 2.0, 1.1).unbound_impl(np.float32)
with pytest.raises(RuntimeError) as e:
potential_interp.execute(np.zeros((1, 3)), np.zeros((1, 3)), np.eye(3), 0)

assert "NonbondedAllPairs::execute_device(): expected P == M*N_*3, got P=3, M*N_*3=6" in str(e)


def make_ref_potential(lambda_plane_idxs, lambda_offset_idxs, beta, cutoff):
@functools.wraps(nonbonded.nonbonded_v3)
def wrapped(conf, params, box, lamb):
num_atoms, _ = conf
no_rescale = np.ones((num_atoms, num_atoms))
return nonbonded.nonbonded_v3(
conf,
params,
box,
lamb,
charge_rescale_mask=no_rescale,
lj_rescale_mask=no_rescale,
beta=beta,
cutoff=cutoff,
lambda_plane_idxs=lambda_plane_idxs,
lambda_offset_idxs=lambda_offset_idxs,
)

return wrapped


@pytest.mark.parametrize("lamb", [0.0, 0.1])
@pytest.mark.parametrize("beta", [2.0])
@pytest.mark.parametrize("cutoff", [1.1])
@pytest.mark.parametrize("precision,rtol,atol", [(np.float64, 1e-8, 1e-8), (np.float32, 1e-4, 5e-4)])
@pytest.mark.parametrize("num_atoms", [33, 65, 231, 1050, 4080])
def test_nonbonded_all_pairs_correctness(
num_atoms,
precision,
rtol,
atol,
cutoff,
beta,
lamb,
example_nonbonded_params,
example_conf,
example_box,
rng: np.random.Generator,
):
conf = example_conf[:num_atoms]
params = example_nonbonded_params[:num_atoms, :]

lambda_plane_idxs = rng.integers(-2, 3, size=(num_atoms,), dtype=np.int32)
lambda_offset_idxs = rng.integers(-2, 3, size=(num_atoms,), dtype=np.int32)

ref_potential = make_ref_potential(lambda_plane_idxs, lambda_offset_idxs, beta, cutoff)
test_potential = NonbondedAllPairs(lambda_plane_idxs, lambda_offset_idxs, beta, cutoff)

GradientTest().compare_forces(
conf, params, example_box, lamb, ref_potential, test_potential, precision=precision, rtol=rtol, atol=atol
)


@pytest.mark.parametrize("lamb", [0.0, 0.1, 0.9, 1.0])
@pytest.mark.parametrize("beta", [2.0])
@pytest.mark.parametrize("cutoff", [1.1])
@pytest.mark.parametrize("precision,rtol,atol", [(np.float64, 1e-8, 1e-8), (np.float32, 1e-4, 5e-4)])
@pytest.mark.parametrize("num_atoms", [33, 231, 4080])
def test_nonbonded_all_pairs_interpolated_correctness(
num_atoms,
precision,
rtol,
atol,
cutoff,
beta,
lamb,
example_nonbonded_params,
example_conf,
example_box,
rng: np.random.Generator,
):
"Compares with jax reference implementation."

conf = example_conf[:num_atoms]
params_initial = example_nonbonded_params[:num_atoms, :]
params_final = params_initial + rng.normal(0, 0.01, size=params_initial.shape)
params = np.concatenate((params_initial, params_final))

lambda_plane_idxs = rng.integers(-2, 3, size=(num_atoms,), dtype=np.int32)
lambda_offset_idxs = rng.integers(-2, 3, size=(num_atoms,), dtype=np.int32)

ref_potential = nonbonded.interpolated(make_ref_potential(lambda_plane_idxs, lambda_offset_idxs, beta, cutoff))
test_potential = NonbondedAllPairsInterpolated(lambda_plane_idxs, lambda_offset_idxs, beta, cutoff)

GradientTest().compare_forces(
conf, params, example_box, lamb, ref_potential, test_potential, precision=precision, rtol=rtol, atol=atol
)
Original file line number Diff line number Diff line change
Expand Up @@ -5,64 +5,11 @@
import numpy as np
import pytest
from common import GradientTest, prepare_reference_nonbonded
from simtk.openmm import app

from timemachine.fe.utils import to_md_units
from timemachine.ff.handlers import openmm_deserializer
from timemachine.lib import potentials
from timemachine.lib.potentials import NonbondedInteractionGroup, NonbondedInteractionGroupInterpolated
from timemachine.potentials import jax_utils, nonbonded


@pytest.fixture(autouse=True)
def set_random_seed():
np.random.seed(2022)
yield


@pytest.fixture()
def rng():
return np.random.default_rng(2022)


@pytest.fixture
def example_system():
pdb_path = "tests/data/5dfr_solv_equil.pdb"
host_pdb = app.PDBFile(pdb_path)
ff = app.ForceField("amber99sbildn.xml", "tip3p.xml")
return (
ff.createSystem(host_pdb.topology, nonbondedMethod=app.NoCutoff, constraints=None, rigidWater=False),
host_pdb.positions,
host_pdb.topology.getPeriodicBoxVectors(),
)


@pytest.fixture
def example_nonbonded_params(example_system):
host_system, _, _ = example_system
host_fns, _ = openmm_deserializer.deserialize_system(host_system, cutoff=1.0)

nonbonded_fn = None
for f in host_fns:
if isinstance(f, potentials.Nonbonded):
nonbonded_fn = f

assert nonbonded_fn is not None
return nonbonded_fn.params


@pytest.fixture
def example_conf(example_system):
_, host_conf, _ = example_system
return np.array([[to_md_units(x), to_md_units(y), to_md_units(z)] for x, y, z in host_conf])


@pytest.fixture
def example_box(example_system):
_, _, box = example_system
return np.asarray(box / box.unit)


def test_nonbonded_interaction_group_invalid_indices():
def make_potential(ligand_idxs, num_atoms):
lambda_plane_idxs = [0] * num_atoms
Expand Down
102 changes: 102 additions & 0 deletions tests/nonbonded/test_nonbonded_pair_list.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import jax

jax.config.update("jax_enable_x64", True)

import functools

import numpy as np
import pytest
from common import GradientTest

from timemachine.lib.potentials import NonbondedPairList
from timemachine.potentials import jax_utils, nonbonded


def test_nonbonded_pair_list_invalid_pair_idxs():
with pytest.raises(RuntimeError) as e:
NonbondedPairList([0], [0], [0], [0], 2.0, 1.1).unbound_impl(np.float32)

assert "pair_idxs.size() must be even, but got 1" in str(e)

with pytest.raises(RuntimeError) as e:
NonbondedPairList([(0, 0)], [(1, 1)], [0], [0], 2.0, 1.1).unbound_impl(np.float32)

assert "illegal pair with src == dst: 0, 0" in str(e)

with pytest.raises(RuntimeError) as e:
NonbondedPairList([(0, 1)], [(1, 1), (2, 2)], [0], [0], 2.0, 1.1).unbound_impl(np.float32)

assert "expected same number of pairs and scale tuples, but got 1 != 2" in str(e)


def make_ref_potential(pair_idxs, scales, lambda_plane_idxs, lambda_offset_idxs, beta, cutoff):
@functools.wraps(nonbonded.nonbonded_v3_on_specific_pairs)
def wrapped(conf, params, box, lamb):

# compute 4d coordinates
w = jax_utils.compute_lifting_parameter(lamb, lambda_plane_idxs, lambda_offset_idxs, cutoff)
conf_4d = jax_utils.augment_dim(conf, w)
box_4d = (1000 * jax.numpy.eye(4)).at[:3, :3].set(box)

vdW, electrostatics = nonbonded.nonbonded_v3_on_specific_pairs(
conf_4d, params, box_4d, pair_idxs[:, 0], pair_idxs[:, 1], beta, cutoff
)
return jax.numpy.sum(scales[:, 1] * vdW + scales[:, 0] * electrostatics)

return wrapped


@pytest.mark.parametrize("lamb", [0.0, 0.1])
@pytest.mark.parametrize("beta", [2.0])
@pytest.mark.parametrize("cutoff", [1.1])
@pytest.mark.parametrize("precision,rtol,atol", [(np.float64, 1e-8, 1e-8), (np.float32, 1e-4, 5e-4)])
@pytest.mark.parametrize("num_atoms", [4080])
@pytest.mark.parametrize("num_atoms_interacting", [1, 30, 1000])
def test_nonbonded_interaction_group_correctness(
num_atoms_interacting,
precision,
rtol,
atol,
cutoff,
beta,
lamb,
example_nonbonded_params,
example_conf,
example_box,
rng: np.random.Generator,
):
"Compares with jax reference implementation."

num_atoms, _ = example_conf.shape

atom_idxs = rng.choice(
num_atoms,
size=(
2,
num_atoms_interacting,
),
replace=False,
).astype(np.int32)

pair_idxs = np.stack(np.meshgrid(atom_idxs[0, :], atom_idxs[1, :])).reshape(2, -1).T
num_pairs, _ = pair_idxs.shape

scales = rng.uniform(0, 1, size=(num_pairs, 2))

lambda_plane_idxs = rng.integers(-2, 3, size=(num_atoms,), dtype=np.int32)
lambda_offset_idxs = rng.integers(-2, 3, size=(num_atoms,), dtype=np.int32)

ref_potential = make_ref_potential(pair_idxs, scales, lambda_plane_idxs, lambda_offset_idxs, beta, cutoff)
test_potential = NonbondedPairList(pair_idxs, scales, lambda_plane_idxs, lambda_offset_idxs, beta, cutoff)

GradientTest().compare_forces(
example_conf,
example_nonbonded_params,
example_box,
lamb,
ref_potential,
test_potential,
precision=precision,
rtol=rtol,
atol=atol,
)
9 changes: 6 additions & 3 deletions timemachine/cpp/src/nonbonded_pair_list.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,22 @@ NonbondedPairList<RealType, Negated, Interpolated>::NonbondedPairList(
kernel_cache_.program(kernel_src.c_str()).kernel("k_add_du_dp_interpolated").instantiate()) {

if (pair_idxs.size() % 2 != 0) {
throw std::runtime_error("pair_idxs.size() must be exactly 2*M");
throw std::runtime_error("pair_idxs.size() must be even, but got " + std::to_string(pair_idxs.size()));
}

for (int i = 0; i < M_; i++) {
auto src = pair_idxs[i * 2 + 0];
auto dst = pair_idxs[i * 2 + 1];
if (src == dst) {
throw std::runtime_error("illegal pair with src == dst");
throw std::runtime_error(
"illegal pair with src == dst: " + std::to_string(src) + ", " + std::to_string(dst));
}
}

if (scales.size() / 2 != M_) {
throw std::runtime_error("bad scales size!");
throw std::runtime_error(
"expected same number of pairs and scale tuples, but got " + std::to_string(M_) +
" != " + std::to_string(scales.size() / 2));
}

gpuErrchk(cudaMalloc(&d_pair_idxs_, M_ * 2 * sizeof(*d_pair_idxs_)));
Expand Down
Loading

0 comments on commit e97aefc

Please sign in to comment.