Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add NonbondedInteractionGroup potential, rename existing nonbonded potentials #578

Merged
merged 52 commits into from
Feb 18, 2022
Merged
Show file tree
Hide file tree
Changes from 44 commits
Commits
Show all changes
52 commits
Select commit Hold shift + click to select a range
fe0fe9b
Rename NonbondedDense -> NonbondedAllPairs
mcwitt Jan 24, 2022
5848e88
Rename NonbondedPairs -> NonbondedPairList
mcwitt Jan 24, 2022
5bf391e
Copy NonbondedInteractionGroup from NonbondedAllPairs
mcwitt Jan 24, 2022
b915a4e
Update interface
mcwitt Feb 15, 2022
3b93f01
Copy indices to device
mcwitt Feb 15, 2022
0c06115
Add wrapper for NonbondedInteractionGroup
mcwitt Feb 9, 2022
6774f28
Add test for raising on invalid indices
mcwitt Feb 9, 2022
71c96e3
Check for empty index set
mcwitt Feb 15, 2022
980185e
Add basic correctness test
mcwitt Feb 9, 2022
945e62a
Implement potential
mcwitt Feb 15, 2022
b532e89
Skip kernel invocation if there are no interactions
mcwitt Feb 13, 2022
0f24785
Add offset to row atom indices in case NR != 0
mcwitt Feb 14, 2022
ecf76d8
Avoid "test_" prefix in fixtures
mcwitt Feb 16, 2022
fcd4199
Test with random non-contiguous row atom indices
mcwitt Feb 16, 2022
4f4bc51
Remove obsolete comments
mcwitt Feb 16, 2022
6185c4b
Define N as const
mcwitt Feb 16, 2022
efa4828
Move input validation to top
mcwitt Feb 16, 2022
3b3ea55
Remove guard
mcwitt Feb 16, 2022
d0f154a
Add test for case with zero interactions
mcwitt Feb 16, 2022
3eb51ce
Fix typo
mcwitt Feb 16, 2022
0cd9168
Return early if neighborlist is empty
mcwitt Feb 16, 2022
05b1b35
Attempt to clarify condition
mcwitt Feb 16, 2022
43fd522
Move stray comment, stream synchronization
mcwitt Feb 16, 2022
653fe4d
Remove print statements, make errors more informative
mcwitt Feb 16, 2022
a70a330
Add comments for clarity, remove extra placeholders
mcwitt Feb 16, 2022
690a57c
Remove obsolete comment
mcwitt Feb 16, 2022
1097ca4
Remove obsolete comment
mcwitt Feb 16, 2022
d728988
Add description to test
mcwitt Feb 16, 2022
eb07db3
Add test comparing with nonbonded_v3_on_specific_pairs
mcwitt Feb 16, 2022
3e36b51
Improve accuracy of comments
mcwitt Feb 16, 2022
2dfb18d
Rename test to something less confusing, add type annotation
mcwitt Feb 16, 2022
39907d6
Account for difference with zero LJ parameters in test
mcwitt Feb 17, 2022
561ce2f
Move jax reference potential out of test
mcwitt Feb 17, 2022
dd543ea
Test with multiple lambda values
mcwitt Feb 17, 2022
25fd4f6
Move cudaStreamSynchronize after cudaMemcpyAsync
mcwitt Feb 17, 2022
860dc9d
Update reference to only compute vdW force when eps_ij != 0
mcwitt Feb 17, 2022
b7ebc26
Rename coords -> conf
mcwitt Feb 17, 2022
89ee22f
Reorder tests
mcwitt Feb 17, 2022
a6486c1
Add consistency test applying constant shift to one group
mcwitt Feb 17, 2022
6163d3c
Rename row/col -> ligand/host in test
mcwitt Feb 17, 2022
824342d
Clean up docstrings
mcwitt Feb 17, 2022
a986a34
Adjust test tolerances
mcwitt Feb 17, 2022
91d1722
Replace itertools product with meshgrid
mcwitt Feb 18, 2022
efb100d
Remove unused import
mcwitt Feb 18, 2022
90d0cd5
Remove obsolete comment
mcwitt Feb 18, 2022
695f2c1
Remove no-op checks
mcwitt Feb 18, 2022
6c5598b
Use nonzero offset indices in consistency checks
mcwitt Feb 18, 2022
dbf4fbe
Use nonzero planes, offsets in correctness test
mcwitt Feb 18, 2022
6d199ca
Expose interpolated potential to Python API, add test
mcwitt Feb 18, 2022
22f9dc9
Remove unused import
mcwitt Feb 18, 2022
cdd37ff
Pass tuples to silence type warnings
mcwitt Feb 18, 2022
6d3ce6c
Merge branch 'master' into nonbonded-interaction-groups
mcwitt Feb 18, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
325 changes: 325 additions & 0 deletions tests/test_nonbonded_interaction_group.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,325 @@
import jax

jax.config.update("jax_enable_x64", True)

import numpy as np
import pytest
from common import GradientTest, prepare_reference_nonbonded
from simtk.openmm import app

from timemachine.fe.utils import to_md_units
from timemachine.ff.handlers import openmm_deserializer
from timemachine.lib import potentials
from timemachine.lib.potentials import NonbondedInteractionGroup
from timemachine.potentials import nonbonded


@pytest.fixture(autouse=True)
def set_random_seed():
np.random.seed(2022)
yield


@pytest.fixture()
def rng():
return np.random.default_rng(2022)


@pytest.fixture
def example_system():
pdb_path = "tests/data/5dfr_solv_equil.pdb"
host_pdb = app.PDBFile(pdb_path)
ff = app.ForceField("amber99sbildn.xml", "tip3p.xml")
return (
ff.createSystem(host_pdb.topology, nonbondedMethod=app.NoCutoff, constraints=None, rigidWater=False),
host_pdb.positions,
host_pdb.topology.getPeriodicBoxVectors(),
)


@pytest.fixture
def example_nonbonded_params(example_system):
host_system, _, _ = example_system
host_fns, _ = openmm_deserializer.deserialize_system(host_system, cutoff=1.0)

nonbonded_fn = None
for f in host_fns:
if isinstance(f, potentials.Nonbonded):
nonbonded_fn = f

assert nonbonded_fn is not None
return nonbonded_fn.params


@pytest.fixture
def example_conf(example_system):
_, host_conf, _ = example_system
return np.array([[to_md_units(x), to_md_units(y), to_md_units(z)] for x, y, z in host_conf])


@pytest.fixture
def example_box(example_system):
_, _, box = example_system
return np.asarray(box / box.unit)


def test_nonbonded_interaction_group_invalid_indices():
def make_potential(ligand_idxs, num_atoms):
lambda_plane_idxs = [0] * num_atoms
lambda_offset_idxs = [0] * num_atoms
return NonbondedInteractionGroup(ligand_idxs, lambda_plane_idxs, lambda_offset_idxs, 1.0, 1.0).unbound_impl(
np.float64
)

with pytest.raises(RuntimeError) as e:
make_potential([], 1)
assert "row_atom_idxs must be nonempty" in str(e)

with pytest.raises(RuntimeError) as e:
make_potential([1, 1], 3)
assert "atom indices must be unique" in str(e)


def test_nonbonded_interaction_group_zero_interactions(rng: np.random.Generator):
num_atoms = 33
num_atoms_ligand = 15
beta = 2.0
lamb = 0.1
cutoff = 1.1
box = 10.0 * np.eye(3)
conf = rng.uniform(0, 1, size=(num_atoms, 3))
ligand_idxs = rng.choice(num_atoms, size=num_atoms_ligand, replace=False).astype(np.int32)

# shift ligand atoms in x by twice the cutoff
conf[ligand_idxs, 0] += 2 * cutoff

params = rng.uniform(0, 1, size=(num_atoms, 3))

potential = NonbondedInteractionGroup(
ligand_idxs,
np.zeros(num_atoms, dtype=np.int32),
np.zeros(num_atoms, dtype=np.int32),
beta,
cutoff,
)

du_dx, du_dp, du_dl, u = potential.unbound_impl(np.float64).execute(conf, params, box, lamb)

assert (du_dx == 0).all()
assert (du_dp == 0).all()
assert du_dl == 0
assert u == 0


@pytest.mark.parametrize("lamb", [0.0, 0.1])
@pytest.mark.parametrize("beta", [2.0])
@pytest.mark.parametrize("cutoff", [1.1])
@pytest.mark.parametrize("precision,rtol,atol", [(np.float64, 1e-8, 1e-8), (np.float32, 1e-4, 5e-4)])
@pytest.mark.parametrize("num_atoms_ligand", [1, 15])
@pytest.mark.parametrize("num_atoms", [33, 231])
def test_nonbonded_interaction_group_correctness(
maxentile marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tests here look great to me now!

num_atoms,
num_atoms_ligand,
precision,
rtol,
atol,
cutoff,
beta,
lamb,
example_nonbonded_params,
example_conf,
example_box,
rng,
):
"Compares with jax reference implementation."

conf = example_conf[:num_atoms]
params = example_nonbonded_params[:num_atoms, :]

ligand_idxs = rng.choice(num_atoms, size=num_atoms_ligand, replace=False).astype(np.int32)
host_idxs = np.setdiff1d(np.arange(num_atoms), ligand_idxs)

def ref_ixngroups(conf, params, box, _):
vdW, electrostatics, _ = nonbonded.nonbonded_v3_interaction_groups(
conf, params, box, ligand_idxs, host_idxs, beta, cutoff
)
return jax.numpy.sum(vdW + electrostatics)

test_ixngroups = NonbondedInteractionGroup(
ligand_idxs,
np.zeros(num_atoms, dtype=np.int32),
np.zeros(num_atoms, dtype=np.int32),
beta,
cutoff,
)

GradientTest().compare_forces(
conf,
params,
example_box,
lamb=lamb,
ref_potential=ref_ixngroups,
test_potential=test_ixngroups,
rtol=rtol,
atol=atol,
precision=precision,
)


@pytest.mark.parametrize("lamb", [0.0, 0.1])
@pytest.mark.parametrize("beta", [2.0])
@pytest.mark.parametrize("cutoff", [1.1])
@pytest.mark.parametrize("precision,rtol,atol", [(np.float64, 1e-8, 1e-8), (np.float32, 1e-4, 5e-4)])
@pytest.mark.parametrize("num_atoms_ligand", [1, 15])
@pytest.mark.parametrize("num_atoms", [33, 231, 1050])
def test_nonbonded_interaction_group_consistency_allpairs_lambda_planes(
num_atoms,
num_atoms_ligand,
precision,
rtol,
atol,
cutoff,
beta,
lamb,
example_nonbonded_params,
example_conf,
example_box,
rng: np.random.Generator,
):
"""Compares with reference nonbonded_v3 potential, which computes
the sum of all pairwise interactions. This uses the identity

U = U_A + U_B + U_AB

where
- U is the all-pairs potential over all atoms
- U_A, U_B are all-pairs potentials for interacting groups A and
B, respectively
- U_AB is the "interaction group" potential, i.e. the sum of
pairwise interactions (a, b) where "a" is in A and "b" is in B

U is computed using the reference potential over all atoms, and
U_A + U_B computed using the reference potential over all atoms
separated into 2 lambda planes according to which interacting
group they belong
"""

conf = example_conf[:num_atoms]
params = example_nonbonded_params[:num_atoms, :]
lambda_offset_idxs = np.zeros(num_atoms, dtype=np.int32)

def make_reference_nonbonded(lambda_plane_idxs):
return prepare_reference_nonbonded(
params=params,
exclusion_idxs=np.array([], dtype=np.int32),
scales=np.zeros((0, 2), dtype=np.float64),
lambda_plane_idxs=lambda_plane_idxs,
lambda_offset_idxs=lambda_offset_idxs,
beta=beta,
cutoff=cutoff,
)

ref_allpairs = make_reference_nonbonded(np.zeros(num_atoms, dtype=np.int32))

ligand_idxs = rng.choice(num_atoms, size=num_atoms_ligand, replace=False).astype(np.int32)
lambda_plane_idxs = np.zeros(num_atoms, dtype=np.int32)
lambda_plane_idxs[ligand_idxs] = 1
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given that most of the alchemical changes (and their contributions to the delta Us) will be coming from the interaction group, we should expand the coverage to include typical ways that the interaction group may depend on lambda.

a) Currently the offset_idxs are all zero for the tests (only plane_idxs are modified in _lambda_planes()). The dependence on lambda isn't fully tested here. Because the w coordinates are:

      w = lambda_plane_idx*cutoff + lambda*lambda_offset_idx
      du/dl = du/dw.dw/dl is only properly tested if lambda_offset_idx != 0

b) Need more test coverage for the InterpolatedNonbondedInteractionGroup (esp in conjunction with varying lambda_plane_idxs and lambda_offset_idxs etc.). I don't think I see any right now.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently the offset_idxs are all zero for the tests (only plane_idxs are modified in _lambda_planes())

Ah, makes sense. Updated to test with nonzero plane and offset idxs in 6c5598b and dbf4fbe.

Need more test coverage for the InterpolatedNonbondedInteractionGroup

Good call, this was untested. Added a Python wrapper and test in 6d199ca


ref_allpairs_minus_ixngroups = make_reference_nonbonded(lambda_plane_idxs)

def ref_ixngroups(*args):
return ref_allpairs(*args) - ref_allpairs_minus_ixngroups(*args)

test_ixngroups = NonbondedInteractionGroup(
ligand_idxs,
np.zeros(num_atoms, dtype=np.int32), # lambda plane indices
lambda_offset_idxs,
beta,
cutoff,
)

GradientTest().compare_forces(
conf,
params,
example_box,
lamb=lamb,
ref_potential=ref_ixngroups,
test_potential=test_ixngroups,
rtol=rtol,
atol=atol,
precision=precision,
)


@pytest.mark.parametrize("lamb", [0.0, 0.1])
@pytest.mark.parametrize("beta", [2.0])
@pytest.mark.parametrize("cutoff", [1.1])
@pytest.mark.parametrize("precision,rtol,atol", [(np.float64, 1e-8, 1e-8), (np.float32, 1e-4, 5e-4)])
@pytest.mark.parametrize("num_atoms_ligand", [1, 15])
@pytest.mark.parametrize("num_atoms", [33, 231])
def test_nonbonded_interaction_group_consistency_allpairs_constant_shift(
num_atoms,
num_atoms_ligand,
precision,
rtol,
atol,
cutoff,
beta,
lamb,
example_nonbonded_params,
example_conf,
example_box,
rng: np.random.Generator,
):
"""Compares with reference nonbonded_v3 potential, which computes
the sum of all pairwise interactions. This uses the identity

U(x') - U(x) = U_AB(x') - U_AB(x)

where
- U is the all-pairs potential over all atoms
- U_A, U_B are all-pairs potentials for interacting groups A and
B, respectively
- U_AB is the "interaction group" potential, i.e. the sum of
pairwise interactions (a, b) where "a" is in A and "b" is in B
- the transformation x -> x' does not affect U_A or U_B (e.g. a
constant translation applied to each atom in one group)
"""

conf = example_conf[:num_atoms]
params = example_nonbonded_params[:num_atoms, :]

def ref_allpairs(conf):
return prepare_reference_nonbonded(
params=params,
exclusion_idxs=np.array([], dtype=np.int32),
scales=np.zeros((0, 2), dtype=np.float64),
lambda_plane_idxs=np.zeros(num_atoms, dtype=np.int32),
lambda_offset_idxs=np.zeros(num_atoms, dtype=np.int32),
beta=beta,
cutoff=cutoff,
)(conf, params, example_box, lamb)

ligand_idxs = rng.choice(num_atoms, size=num_atoms_ligand, replace=False).astype(np.int32)

def test_ixngroups(conf):
_, _, _, u = (
NonbondedInteractionGroup(
ligand_idxs,
np.zeros(num_atoms, dtype=np.int32),
np.zeros(num_atoms, dtype=np.int32),
beta,
cutoff,
)
.unbound_impl(precision)
.execute(conf, params, example_box, lamb)
)
return u

conf_prime = np.array(conf)
conf_prime[ligand_idxs] += rng.normal(0, 0.01, size=(3,))

ref_delta = ref_allpairs(conf_prime) - ref_allpairs(conf)
test_delta = test_ixngroups(conf_prime) - test_ixngroups(conf)

np.testing.assert_allclose(ref_delta, test_delta, rtol=rtol, atol=atol)
6 changes: 4 additions & 2 deletions timemachine/cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,9 @@ pybind11_add_module(${LIBRARY_NAME} SHARED NO_EXTRAS
src/gpu_utils.cu
src/vendored/hilbert.cpp
src/nonbonded.cu
src/nonbonded_dense.cu
src/nonbonded_pairs.cu
src/nonbonded_all_pairs.cu
src/nonbonded_pair_list.cu
src/nonbonded_interaction_group.cu
src/neighborlist.cu
src/harmonic_bond.cu
src/harmonic_angle.cu
Expand All @@ -54,6 +55,7 @@ pybind11_add_module(${LIBRARY_NAME} SHARED NO_EXTRAS
src/rmsd_align.cpp
src/summed_potential.cu
src/device_buffer.cu
src/kernels/k_nonbonded.cu
src/kernels/nonbonded_common.cu
)

Expand Down
Loading