Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add NonbondedInteractionGroup potential, rename existing nonbonded potentials #578

Merged
merged 52 commits into from
Feb 18, 2022
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
52 commits
Select commit Hold shift + click to select a range
fe0fe9b
Rename NonbondedDense -> NonbondedAllPairs
mcwitt Jan 24, 2022
5848e88
Rename NonbondedPairs -> NonbondedPairList
mcwitt Jan 24, 2022
5bf391e
Copy NonbondedInteractionGroup from NonbondedAllPairs
mcwitt Jan 24, 2022
b915a4e
Update interface
mcwitt Feb 15, 2022
3b93f01
Copy indices to device
mcwitt Feb 15, 2022
0c06115
Add wrapper for NonbondedInteractionGroup
mcwitt Feb 9, 2022
6774f28
Add test for raising on invalid indices
mcwitt Feb 9, 2022
71c96e3
Check for empty index set
mcwitt Feb 15, 2022
980185e
Add basic correctness test
mcwitt Feb 9, 2022
945e62a
Implement potential
mcwitt Feb 15, 2022
b532e89
Skip kernel invocation if there are no interactions
mcwitt Feb 13, 2022
0f24785
Add offset to row atom indices in case NR != 0
mcwitt Feb 14, 2022
ecf76d8
Avoid "test_" prefix in fixtures
mcwitt Feb 16, 2022
fcd4199
Test with random non-contiguous row atom indices
mcwitt Feb 16, 2022
4f4bc51
Remove obsolete comments
mcwitt Feb 16, 2022
6185c4b
Define N as const
mcwitt Feb 16, 2022
efa4828
Move input validation to top
mcwitt Feb 16, 2022
3b3ea55
Remove guard
mcwitt Feb 16, 2022
d0f154a
Add test for case with zero interactions
mcwitt Feb 16, 2022
3eb51ce
Fix typo
mcwitt Feb 16, 2022
0cd9168
Return early if neighborlist is empty
mcwitt Feb 16, 2022
05b1b35
Attempt to clarify condition
mcwitt Feb 16, 2022
43fd522
Move stray comment, stream synchronization
mcwitt Feb 16, 2022
653fe4d
Remove print statements, make errors more informative
mcwitt Feb 16, 2022
a70a330
Add comments for clarity, remove extra placeholders
mcwitt Feb 16, 2022
690a57c
Remove obsolete comment
mcwitt Feb 16, 2022
1097ca4
Remove obsolete comment
mcwitt Feb 16, 2022
d728988
Add description to test
mcwitt Feb 16, 2022
eb07db3
Add test comparing with nonbonded_v3_on_specific_pairs
mcwitt Feb 16, 2022
3e36b51
Improve accuracy of comments
mcwitt Feb 16, 2022
2dfb18d
Rename test to something less confusing, add type annotation
mcwitt Feb 16, 2022
39907d6
Account for difference with zero LJ parameters in test
mcwitt Feb 17, 2022
561ce2f
Move jax reference potential out of test
mcwitt Feb 17, 2022
dd543ea
Test with multiple lambda values
mcwitt Feb 17, 2022
25fd4f6
Move cudaStreamSynchronize after cudaMemcpyAsync
mcwitt Feb 17, 2022
860dc9d
Update reference to only compute vdW force when eps_ij != 0
mcwitt Feb 17, 2022
b7ebc26
Rename coords -> conf
mcwitt Feb 17, 2022
89ee22f
Reorder tests
mcwitt Feb 17, 2022
a6486c1
Add consistency test applying constant shift to one group
mcwitt Feb 17, 2022
6163d3c
Rename row/col -> ligand/host in test
mcwitt Feb 17, 2022
824342d
Clean up docstrings
mcwitt Feb 17, 2022
a986a34
Adjust test tolerances
mcwitt Feb 17, 2022
91d1722
Replace itertools product with meshgrid
mcwitt Feb 18, 2022
efb100d
Remove unused import
mcwitt Feb 18, 2022
90d0cd5
Remove obsolete comment
mcwitt Feb 18, 2022
695f2c1
Remove no-op checks
mcwitt Feb 18, 2022
6c5598b
Use nonzero offset indices in consistency checks
mcwitt Feb 18, 2022
dbf4fbe
Use nonzero planes, offsets in correctness test
mcwitt Feb 18, 2022
6d199ca
Expose interpolated potential to Python API, add test
mcwitt Feb 18, 2022
22f9dc9
Remove unused import
mcwitt Feb 18, 2022
cdd37ff
Pass tuples to silence type warnings
mcwitt Feb 18, 2022
6d3ce6c
Merge branch 'master' into nonbonded-interaction-groups
mcwitt Feb 18, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 129 additions & 0 deletions tests/test_nonbonded_interaction_group.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import jax

jax.config.update("jax_enable_x64", True)

import numpy as np
import pytest
from common import GradientTest, prepare_reference_nonbonded
from simtk.openmm import app

from timemachine.fe.utils import to_md_units
from timemachine.ff.handlers import openmm_deserializer
from timemachine.lib import potentials
from timemachine.lib.potentials import NonbondedInteractionGroup


@pytest.fixture(autouse=True)
def set_random_seed():
np.random.seed(2021)
yield


@pytest.fixture
def test_system():
maxentile marked this conversation as resolved.
Show resolved Hide resolved
pdb_path = "tests/data/5dfr_solv_equil.pdb"
host_pdb = app.PDBFile(pdb_path)
ff = app.ForceField("amber99sbildn.xml", "tip3p.xml")
return (
ff.createSystem(host_pdb.topology, nonbondedMethod=app.NoCutoff, constraints=None, rigidWater=False),
host_pdb.positions,
host_pdb.topology.getPeriodicBoxVectors(),
)


@pytest.fixture
def ref_nonbonded_potential(test_system):
host_system, _, _ = test_system
host_fns, _ = openmm_deserializer.deserialize_system(host_system, cutoff=1.0)

nonbonded_fn = None
for f in host_fns:
if isinstance(f, potentials.Nonbonded):
nonbonded_fn = f

return nonbonded_fn


@pytest.fixture
def test_conf(test_system):
_, host_coords, _ = test_system
return np.array([[to_md_units(x), to_md_units(y), to_md_units(z)] for x, y, z in host_coords])


@pytest.fixture
def test_box(test_system):
_, _, box = test_system
return np.asarray(box / box.unit)


@pytest.mark.parametrize("beta", [2.0])
@pytest.mark.parametrize("cutoff", [1.1])
@pytest.mark.parametrize("precision,rtol,atol", [(np.float64, 1e-8, 1e-8), (np.float32, 1e-4, 5e-4)])
@pytest.mark.parametrize("num_row_atoms", [1, 15])
@pytest.mark.parametrize("num_atoms", [33, 231, 1050])
def test_nonbonded_interaction_group_correctness(
maxentile marked this conversation as resolved.
Show resolved Hide resolved
num_atoms, num_row_atoms, precision, rtol, atol, cutoff, beta, ref_nonbonded_potential, test_conf, test_box
):

test_conf = test_conf[:num_atoms]
test_params = ref_nonbonded_potential.params[:num_atoms, :]
test_lambda_offset_idxs = np.zeros(num_atoms, dtype=np.int32)

def make_reference_nonbonded(lambda_plane_idxs):
return prepare_reference_nonbonded(
params=test_params,
exclusion_idxs=np.array([], dtype=np.int32),
scales=np.zeros((0, 2), dtype=np.float64),
lambda_plane_idxs=lambda_plane_idxs,
lambda_offset_idxs=test_lambda_offset_idxs,
beta=beta,
cutoff=cutoff,
)

ref_allpairs = make_reference_nonbonded(np.zeros(num_atoms, dtype=np.int32))

num_col_atoms = num_atoms - num_row_atoms

ref_allpairs_minus_ixngroups = make_reference_nonbonded(
np.concatenate((np.zeros(num_row_atoms, dtype=np.int32), np.ones(num_col_atoms, dtype=np.int32)))
)

def ref_ixngroups(*args):
return ref_allpairs(*args) - ref_allpairs_minus_ixngroups(*args)

test_ixngroups = NonbondedInteractionGroup(
np.arange(0, num_row_atoms, dtype=np.int32),
maxentile marked this conversation as resolved.
Show resolved Hide resolved
np.zeros(num_atoms, dtype=np.int32), # lambda plane indices
test_lambda_offset_idxs,
beta,
cutoff,
)

GradientTest().compare_forces(
test_conf,
test_params,
test_box,
lamb=0.1,
badisa marked this conversation as resolved.
Show resolved Hide resolved
ref_potential=ref_ixngroups,
test_potential=test_ixngroups,
rtol=rtol,
atol=atol,
precision=precision,
)


def test_nonbonded_interaction_group_invalid_indices():
def make_potential(row_atom_idxs, num_atoms):
lambda_plane_idxs = [0] * num_atoms
lambda_offset_idxs = [0] * num_atoms
return NonbondedInteractionGroup(row_atom_idxs, lambda_plane_idxs, lambda_offset_idxs, 1.0, 1.0).unbound_impl(
np.float64
)

with pytest.raises(RuntimeError) as e:
make_potential([], 1)
assert "row_atom_idxs must be nonempty" in str(e)

with pytest.raises(RuntimeError) as e:
make_potential([1, 1], 3)
assert "atom indices must be unique" in str(e)
6 changes: 4 additions & 2 deletions timemachine/cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,9 @@ pybind11_add_module(${LIBRARY_NAME} SHARED NO_EXTRAS
src/gpu_utils.cu
src/vendored/hilbert.cpp
src/nonbonded.cu
src/nonbonded_dense.cu
src/nonbonded_pairs.cu
src/nonbonded_all_pairs.cu
src/nonbonded_pair_list.cu
src/nonbonded_interaction_group.cu
src/neighborlist.cu
src/harmonic_bond.cu
src/harmonic_angle.cu
Expand All @@ -54,6 +55,7 @@ pybind11_add_module(${LIBRARY_NAME} SHARED NO_EXTRAS
src/rmsd_align.cpp
src/summed_potential.cu
src/device_buffer.cu
src/kernels/k_nonbonded.cu
src/kernels/nonbonded_common.cu
)

Expand Down
91 changes: 91 additions & 0 deletions timemachine/cpp/src/kernels/k_nonbonded.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
#include "k_nonbonded.cuh"

void __global__ k_coords_to_kv(
const int N,
const double *coords,
const double *box,
const unsigned int *bin_to_idx,
unsigned int *keys,
unsigned int *vals) {

const int atom_idx = blockIdx.x * blockDim.x + threadIdx.x;

if (atom_idx >= N) {
return;
}

// these coords have to be centered
double bx = box[0 * 3 + 0];
double by = box[1 * 3 + 1];
double bz = box[2 * 3 + 2];

double binWidth = max(max(bx, by), bz) / 255.0;

double x = coords[atom_idx * 3 + 0];
double y = coords[atom_idx * 3 + 1];
double z = coords[atom_idx * 3 + 2];

x -= bx * floor(x / bx);
y -= by * floor(y / by);
z -= bz * floor(z / bz);

unsigned int bin_x = x / binWidth;
unsigned int bin_y = y / binWidth;
unsigned int bin_z = z / binWidth;

keys[atom_idx] = bin_to_idx[bin_x * 256 * 256 + bin_y * 256 + bin_z];
// uncomment below if you want to preserve the atom ordering
// keys[atom_idx] = atom_idx;
vals[atom_idx] = atom_idx;
}

// TODO: DRY with k_coords_to_kv
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

definitely DRY for this one (do this in a later PR is fine), I think you can simply implement the old version with atom_idxs = np.arange(N)

void __global__ k_coords_to_kv_gather(
const int N,
const unsigned int *atom_idxs,
const double *coords,
const double *box,
const unsigned int *bin_to_idx,
unsigned int *keys,
unsigned int *vals) {

const int idx = blockIdx.x * blockDim.x + threadIdx.x;

if (idx >= N) {
return;
}

const int atom_idx = atom_idxs[idx];

// these coords have to be centered
double bx = box[0 * 3 + 0];
double by = box[1 * 3 + 1];
double bz = box[2 * 3 + 2];

double binWidth = max(max(bx, by), bz) / 255.0;

double x = coords[atom_idx * 3 + 0];
double y = coords[atom_idx * 3 + 1];
double z = coords[atom_idx * 3 + 2];

x -= bx * floor(x / bx);
y -= by * floor(y / by);
z -= bz * floor(z / bz);

unsigned int bin_x = x / binWidth;
unsigned int bin_y = y / binWidth;
unsigned int bin_z = z / binWidth;

keys[idx] = bin_to_idx[bin_x * 256 * 256 + bin_y * 256 + bin_z];
// uncomment below if you want to preserve the atom ordering
// keys[idx] = atom_idx;
vals[idx] = atom_idx;
}

void __global__ k_arange(int N, unsigned int *arr) {
const int atom_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (atom_idx >= N) {
return;
}
arr[atom_idx] = atom_idx;
}
Loading