Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add NonbondedInteractionGroup potential, rename existing nonbonded potentials #578

Merged
merged 52 commits into from
Feb 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
52 commits
Select commit Hold shift + click to select a range
fe0fe9b
Rename NonbondedDense -> NonbondedAllPairs
mcwitt Jan 24, 2022
5848e88
Rename NonbondedPairs -> NonbondedPairList
mcwitt Jan 24, 2022
5bf391e
Copy NonbondedInteractionGroup from NonbondedAllPairs
mcwitt Jan 24, 2022
b915a4e
Update interface
mcwitt Feb 15, 2022
3b93f01
Copy indices to device
mcwitt Feb 15, 2022
0c06115
Add wrapper for NonbondedInteractionGroup
mcwitt Feb 9, 2022
6774f28
Add test for raising on invalid indices
mcwitt Feb 9, 2022
71c96e3
Check for empty index set
mcwitt Feb 15, 2022
980185e
Add basic correctness test
mcwitt Feb 9, 2022
945e62a
Implement potential
mcwitt Feb 15, 2022
b532e89
Skip kernel invocation if there are no interactions
mcwitt Feb 13, 2022
0f24785
Add offset to row atom indices in case NR != 0
mcwitt Feb 14, 2022
ecf76d8
Avoid "test_" prefix in fixtures
mcwitt Feb 16, 2022
fcd4199
Test with random non-contiguous row atom indices
mcwitt Feb 16, 2022
4f4bc51
Remove obsolete comments
mcwitt Feb 16, 2022
6185c4b
Define N as const
mcwitt Feb 16, 2022
efa4828
Move input validation to top
mcwitt Feb 16, 2022
3b3ea55
Remove guard
mcwitt Feb 16, 2022
d0f154a
Add test for case with zero interactions
mcwitt Feb 16, 2022
3eb51ce
Fix typo
mcwitt Feb 16, 2022
0cd9168
Return early if neighborlist is empty
mcwitt Feb 16, 2022
05b1b35
Attempt to clarify condition
mcwitt Feb 16, 2022
43fd522
Move stray comment, stream synchronization
mcwitt Feb 16, 2022
653fe4d
Remove print statements, make errors more informative
mcwitt Feb 16, 2022
a70a330
Add comments for clarity, remove extra placeholders
mcwitt Feb 16, 2022
690a57c
Remove obsolete comment
mcwitt Feb 16, 2022
1097ca4
Remove obsolete comment
mcwitt Feb 16, 2022
d728988
Add description to test
mcwitt Feb 16, 2022
eb07db3
Add test comparing with nonbonded_v3_on_specific_pairs
mcwitt Feb 16, 2022
3e36b51
Improve accuracy of comments
mcwitt Feb 16, 2022
2dfb18d
Rename test to something less confusing, add type annotation
mcwitt Feb 16, 2022
39907d6
Account for difference with zero LJ parameters in test
mcwitt Feb 17, 2022
561ce2f
Move jax reference potential out of test
mcwitt Feb 17, 2022
dd543ea
Test with multiple lambda values
mcwitt Feb 17, 2022
25fd4f6
Move cudaStreamSynchronize after cudaMemcpyAsync
mcwitt Feb 17, 2022
860dc9d
Update reference to only compute vdW force when eps_ij != 0
mcwitt Feb 17, 2022
b7ebc26
Rename coords -> conf
mcwitt Feb 17, 2022
89ee22f
Reorder tests
mcwitt Feb 17, 2022
a6486c1
Add consistency test applying constant shift to one group
mcwitt Feb 17, 2022
6163d3c
Rename row/col -> ligand/host in test
mcwitt Feb 17, 2022
824342d
Clean up docstrings
mcwitt Feb 17, 2022
a986a34
Adjust test tolerances
mcwitt Feb 17, 2022
91d1722
Replace itertools product with meshgrid
mcwitt Feb 18, 2022
efb100d
Remove unused import
mcwitt Feb 18, 2022
90d0cd5
Remove obsolete comment
mcwitt Feb 18, 2022
695f2c1
Remove no-op checks
mcwitt Feb 18, 2022
6c5598b
Use nonzero offset indices in consistency checks
mcwitt Feb 18, 2022
dbf4fbe
Use nonzero planes, offsets in correctness test
mcwitt Feb 18, 2022
6d199ca
Expose interpolated potential to Python API, add test
mcwitt Feb 18, 2022
22f9dc9
Remove unused import
mcwitt Feb 18, 2022
cdd37ff
Pass tuples to silence type warnings
mcwitt Feb 18, 2022
6d3ce6c
Merge branch 'master' into nonbonded-interaction-groups
mcwitt Feb 18, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
410 changes: 410 additions & 0 deletions tests/test_nonbonded_interaction_group.py

Large diffs are not rendered by default.

13 changes: 2 additions & 11 deletions tests/test_parameter_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,13 @@

config.update("jax_enable_x64", True)
import copy
import functools

import jax.numpy as jnp
import numpy as np
from common import GradientTest, prepare_water_system

from timemachine.lib import potentials


def interpolated_potential(conf, params, box, lamb, u_fn):
    """Evaluate ``u_fn`` with parameters linearly interpolated in ``lamb``.

    ``params`` is split along its first axis into two equal halves; the
    effective parameters are ``(1 - lamb) * first_half + lamb * second_half``.

    Parameters
    ----------
    conf, box
        Passed through to ``u_fn`` unchanged.
    params
        Array whose first axis has even length (the two stacked parameter sets).
    lamb
        Interpolation weight; also forwarded to ``u_fn``.
    u_fn
        Callable invoked as ``u_fn(conf, new_params, box, lamb)``.

    Returns
    -------
    Whatever ``u_fn`` returns for the interpolated parameters.

    Raises
    ------
    AssertionError
        If ``params`` cannot be split into two equal halves along axis 0.
    """
    assert params.size % 2 == 0
    # The split below is along axis 0, so that axis must be even as well.
    # Checking only the total size lets e.g. a (3, 2) array through, after
    # which the (1, 2) and (2, 2) halves would silently broadcast together.
    assert params.shape[0] % 2 == 0

    CP = params.shape[0] // 2
    new_params = (1 - lamb) * params[:CP] + lamb * params[CP:]

    return u_fn(conf, new_params, box, lamb)
from timemachine.potentials import nonbonded


class TestInterpolatedPotential(GradientTest):
Expand Down Expand Up @@ -57,7 +48,7 @@ def test_nonbonded(self):

print("lambda", lamb, "cutoff", cutoff, "precision", precision, "xshape", coords.shape)

ref_interpolated_potential = functools.partial(interpolated_potential, u_fn=ref_potential)
ref_interpolated_potential = nonbonded.interpolated(ref_potential)

test_interpolated_potential = potentials.NonbondedInterpolated(*test_potential.args)

Expand Down
6 changes: 4 additions & 2 deletions timemachine/cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,9 @@ pybind11_add_module(${LIBRARY_NAME} SHARED NO_EXTRAS
src/gpu_utils.cu
src/vendored/hilbert.cpp
src/nonbonded.cu
src/nonbonded_dense.cu
src/nonbonded_pairs.cu
src/nonbonded_all_pairs.cu
src/nonbonded_pair_list.cu
src/nonbonded_interaction_group.cu
src/neighborlist.cu
src/harmonic_bond.cu
src/harmonic_angle.cu
Expand All @@ -54,6 +55,7 @@ pybind11_add_module(${LIBRARY_NAME} SHARED NO_EXTRAS
src/rmsd_align.cpp
src/summed_potential.cu
src/device_buffer.cu
src/kernels/k_nonbonded.cu
src/kernels/nonbonded_common.cu
)

Expand Down
91 changes: 91 additions & 0 deletions timemachine/cpp/src/kernels/k_nonbonded.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
#include "k_nonbonded.cuh"

// Writes one (key, value) pair per atom: the key is looked up in bin_to_idx from the
// atom's spatial bin, the value is the atom's own index.
// Each thread handles the atom at its global thread index; threads at or past N do nothing.
//   N          - number of atoms to process
//   coords     - read at [3 * atom + {0, 1, 2}] as x, y, z
//   box        - read at indices 0, 4 and 8 (the diagonal, assuming a row-major 3x3 box -- TODO confirm)
//   bin_to_idx - lookup table indexed by bin_x * 256 * 256 + bin_y * 256 + bin_z
//   keys, vals - outputs, written at [atom]
void __global__ k_coords_to_kv(
    const int N,
    const double *coords,
    const double *box,
    const unsigned int *bin_to_idx,
    unsigned int *keys,
    unsigned int *vals) {

    const int tid = blockIdx.x * blockDim.x + threadIdx.x;

    if (tid < N) {
        // these coords have to be centered
        const double len_x = box[0 * 3 + 0];
        const double len_y = box[1 * 3 + 1];
        const double len_z = box[2 * 3 + 2];

        // bin width is the largest box length divided by 255
        const double bin_width = max(max(len_x, len_y), len_z) / 255.0;

        const double raw_x = coords[tid * 3 + 0];
        const double raw_y = coords[tid * 3 + 1];
        const double raw_z = coords[tid * 3 + 2];

        // shift each coordinate by a whole number of box lengths
        const double wrapped_x = raw_x - len_x * floor(raw_x / len_x);
        const double wrapped_y = raw_y - len_y * floor(raw_y / len_y);
        const double wrapped_z = raw_z - len_z * floor(raw_z / len_z);

        const unsigned int bin_x = wrapped_x / bin_width;
        const unsigned int bin_y = wrapped_y / bin_width;
        const unsigned int bin_z = wrapped_z / bin_width;

        const unsigned int flat_bin = bin_x * 256 * 256 + bin_y * 256 + bin_z;

        keys[tid] = bin_to_idx[flat_bin];
        // to preserve the atom ordering instead, use:
        // keys[tid] = tid;
        vals[tid] = tid;
    }
}

// TODO: DRY with k_coords_to_kv
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Definitely DRY this one up (doing it in a later PR is fine); I think you can simply implement the old version with atom_idxs = np.arange(N).

// Same key/value generation as k_coords_to_kv, but for a selection of atoms.
// Thread `slot` reads the atom index atom_idxs[slot], bins that atom's coordinates, and
// writes the key and the atom index at position `slot` of the outputs.
//   N          - number of selected atoms (threads at or past N do nothing)
//   atom_idxs  - read at [slot]; gives the index into coords
//   coords     - read at [3 * atom + {0, 1, 2}] as x, y, z
//   box        - read at indices 0, 4 and 8 (the diagonal, assuming a row-major 3x3 box -- TODO confirm)
//   bin_to_idx - lookup table indexed by bin_x * 256 * 256 + bin_y * 256 + bin_z
//   keys, vals - outputs, written at [slot]
void __global__ k_coords_to_kv_gather(
    const int N,
    const unsigned int *atom_idxs,
    const double *coords,
    const double *box,
    const unsigned int *bin_to_idx,
    unsigned int *keys,
    unsigned int *vals) {

    const int slot = blockIdx.x * blockDim.x + threadIdx.x;

    if (slot < N) {
        const int atom_idx = atom_idxs[slot];

        // these coords have to be centered
        const double len_x = box[0 * 3 + 0];
        const double len_y = box[1 * 3 + 1];
        const double len_z = box[2 * 3 + 2];

        // bin width is the largest box length divided by 255
        const double bin_width = max(max(len_x, len_y), len_z) / 255.0;

        const double raw_x = coords[atom_idx * 3 + 0];
        const double raw_y = coords[atom_idx * 3 + 1];
        const double raw_z = coords[atom_idx * 3 + 2];

        // shift each coordinate by a whole number of box lengths
        const double wrapped_x = raw_x - len_x * floor(raw_x / len_x);
        const double wrapped_y = raw_y - len_y * floor(raw_y / len_y);
        const double wrapped_z = raw_z - len_z * floor(raw_z / len_z);

        const unsigned int bin_x = wrapped_x / bin_width;
        const unsigned int bin_y = wrapped_y / bin_width;
        const unsigned int bin_z = wrapped_z / bin_width;

        const unsigned int flat_bin = bin_x * 256 * 256 + bin_y * 256 + bin_z;

        keys[slot] = bin_to_idx[flat_bin];
        // to preserve the atom ordering instead, use:
        // keys[slot] = atom_idx;
        vals[slot] = atom_idx;
    }
}

// Fills arr[i] = i for every i in [0, N); each thread writes the element at its
// global thread index and threads at or past N write nothing.
void __global__ k_arange(int N, unsigned int *arr) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < N) {
        arr[i] = i;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,45 +5,26 @@
#include "nonbonded_common.cuh"
#include "surreal.cuh"

void __global__ k_arange(int N, unsigned int *arr);

// generate kv values from coordinates to be radix sorted
void __global__ k_coords_to_kv(
const int N,
const double *coords,
const double *box,
const unsigned int *bin_to_idx,
unsigned int *keys,
unsigned int *vals) {

const int atom_idx = blockIdx.x * blockDim.x + threadIdx.x;

if (atom_idx >= N) {
return;
}

// these coords have to be centered
double bx = box[0 * 3 + 0];
double by = box[1 * 3 + 1];
double bz = box[2 * 3 + 2];

double binWidth = max(max(bx, by), bz) / 255.0;
unsigned int *vals);

double x = coords[atom_idx * 3 + 0];
double y = coords[atom_idx * 3 + 1];
double z = coords[atom_idx * 3 + 2];

x -= bx * floor(x / bx);
y -= by * floor(y / by);
z -= bz * floor(z / bz);

unsigned int bin_x = x / binWidth;
unsigned int bin_y = y / binWidth;
unsigned int bin_z = z / binWidth;

keys[atom_idx] = bin_to_idx[bin_x * 256 * 256 + bin_y * 256 + bin_z];
// uncomment below if you want to preserve the atom ordering
// keys[atom_idx] = atom_idx;
vals[atom_idx] = atom_idx;
}
// variant of k_coords_to_kv allowing the selection of a subset of coordinates
void __global__ k_coords_to_kv_gather(
const int N, // number of atoms in selection
const unsigned int *atom_idxs, // [N] indices of atoms to select
const double *coords,
const double *box,
const unsigned int *bin_to_idx,
unsigned int *keys,
unsigned int *vals);

template <typename RealType>
void __global__ k_check_rebuild_box(const int N, const double *new_box, const double *old_box, int *rebuild) {
Expand Down Expand Up @@ -282,7 +263,8 @@ template <
bool COMPUTE_DU_DP>
// void __device__ __forceinline__ v_nonbonded_unified(
void __device__ v_nonbonded_unified(
const int N,
const int NC,
const int NR,
const double *__restrict__ coords,
const double *__restrict__ params, // [N]
const double *__restrict__ box,
Expand Down Expand Up @@ -316,6 +298,12 @@ void __device__ v_nonbonded_unified(
// int lambda_offset_i = atom_i_idx < N ? lambda_offset_idxs[atom_i_idx] : 0;
// int lambda_plane_i = atom_i_idx < N ? lambda_plane_idxs[atom_i_idx] : 0;

const int N = NC + NR;

if (NR != 0) {
atom_i_idx += NC;
}

RealType ci_x = atom_i_idx < N ? coords[atom_i_idx * 3 + 0] : 0;
RealType ci_y = atom_i_idx < N ? coords[atom_i_idx * 3 + 1] : 0;
RealType ci_z = atom_i_idx < N ? coords[atom_i_idx * 3 + 2] : 0;
Expand Down Expand Up @@ -348,15 +336,15 @@ void __device__ v_nonbonded_unified(
// int lambda_offset_j = atom_j_idx < N ? lambda_offset_idxs[atom_j_idx] : 0;
// int lambda_plane_j = atom_j_idx < N ? lambda_plane_idxs[atom_j_idx] : 0;

RealType cj_x = atom_j_idx < N ? coords[atom_j_idx * 3 + 0] : 0;
RealType cj_y = atom_j_idx < N ? coords[atom_j_idx * 3 + 1] : 0;
RealType cj_z = atom_j_idx < N ? coords[atom_j_idx * 3 + 2] : 0;
RealType cj_w = atom_j_idx < N ? coords_w[atom_j_idx] : 0;
RealType cj_x = atom_j_idx < NC ? coords[atom_j_idx * 3 + 0] : 0;
RealType cj_y = atom_j_idx < NC ? coords[atom_j_idx * 3 + 1] : 0;
RealType cj_z = atom_j_idx < NC ? coords[atom_j_idx * 3 + 2] : 0;
RealType cj_w = atom_j_idx < NC ? coords_w[atom_j_idx] : 0;

RealType dq_dl_j = atom_j_idx < N ? dp_dl[atom_j_idx * 3 + 0] : 0;
RealType dsig_dl_j = atom_j_idx < N ? dp_dl[atom_j_idx * 3 + 1] : 0;
RealType deps_dl_j = atom_j_idx < N ? dp_dl[atom_j_idx * 3 + 2] : 0;
RealType dw_dl_j = atom_j_idx < N ? dw_dl[atom_j_idx] : 0;
RealType dq_dl_j = atom_j_idx < NC ? dp_dl[atom_j_idx * 3 + 0] : 0;
RealType dsig_dl_j = atom_j_idx < NC ? dp_dl[atom_j_idx * 3 + 1] : 0;
RealType deps_dl_j = atom_j_idx < NC ? dp_dl[atom_j_idx * 3 + 2] : 0;
RealType dw_dl_j = atom_j_idx < NC ? dw_dl[atom_j_idx] : 0;

unsigned long long gj_x = 0;
unsigned long long gj_y = 0;
Expand All @@ -366,9 +354,9 @@ void __device__ v_nonbonded_unified(
int lj_param_idx_sig_j = atom_j_idx * 3 + 1;
int lj_param_idx_eps_j = atom_j_idx * 3 + 2;

RealType qj = atom_j_idx < N ? params[charge_param_idx_j] : 0;
RealType sig_j = atom_j_idx < N ? params[lj_param_idx_sig_j] : 0;
RealType eps_j = atom_j_idx < N ? params[lj_param_idx_eps_j] : 0;
RealType qj = atom_j_idx < NC ? params[charge_param_idx_j] : 0;
RealType sig_j = atom_j_idx < NC ? params[lj_param_idx_sig_j] : 0;
RealType eps_j = atom_j_idx < NC ? params[lj_param_idx_eps_j] : 0;

unsigned long long g_qj = 0;
unsigned long long g_sigj = 0;
Expand Down Expand Up @@ -403,11 +391,17 @@ void __device__ v_nonbonded_unified(
d2ij += delta_w * delta_w;
}

const bool valid_ij =
atom_i_idx < N &&
((NR == 0) ? atom_i_idx < atom_j_idx && atom_j_idx < N // all-pairs case, only compute the upper tri
// 0 <= i < N, i < j < N
: atom_j_idx < NC); // ixn groups case, compute all pairwise ixns
// NC <= i < N, 0 <= j < NC

// (ytz): note that d2ij must be *strictly* less than cutoff_squared. This is because we set the
// non-interacting atoms to exactly real_cutoff*real_cutoff. This ensures that atoms whose 4th dimension
// is set to cutoff are non-interacting.
if (d2ij < cutoff_squared && atom_j_idx > atom_i_idx && atom_j_idx < N && atom_i_idx < N) {

if (d2ij < cutoff_squared && valid_ij) {
// electrostatics
RealType u;
RealType es_prefactor;
Expand Down Expand Up @@ -547,7 +541,8 @@ void __device__ v_nonbonded_unified(

template <typename RealType, bool COMPUTE_U, bool COMPUTE_DU_DX, bool COMPUTE_DU_DL, bool COMPUTE_DU_DP>
void __global__ k_nonbonded_unified(
const int N,
const int NC,
const int NR,
const double *__restrict__ coords,
const double *__restrict__ params, // [N]
const double *__restrict__ box,
Expand All @@ -567,6 +562,12 @@ void __global__ k_nonbonded_unified(
int row_block_idx = ixn_tiles[tile_idx];
int atom_i_idx = row_block_idx * 32 + threadIdx.x;

const int N = NC + NR;

if (NR != 0) {
atom_i_idx += NC;
}

RealType dq_dl_i = atom_i_idx < N ? dp_dl[atom_i_idx * 3 + 0] : 0;
RealType dsig_dl_i = atom_i_idx < N ? dp_dl[atom_i_idx * 3 + 1] : 0;
RealType deps_dl_i = atom_i_idx < N ? dp_dl[atom_i_idx * 3 + 2] : 0;
Expand All @@ -587,7 +588,8 @@ void __global__ k_nonbonded_unified(

if (tile_is_vanilla) {
v_nonbonded_unified<RealType, 0, COMPUTE_U, COMPUTE_DU_DX, COMPUTE_DU_DL, COMPUTE_DU_DP>(
N,
NC,
NR,
coords,
params,
box,
Expand All @@ -604,7 +606,8 @@ void __global__ k_nonbonded_unified(
u_buffer);
} else {
v_nonbonded_unified<RealType, 1, COMPUTE_U, COMPUTE_DU_DX, COMPUTE_DU_DL, COMPUTE_DU_DP>(
N,
NC,
NR,
coords,
params,
box,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ void __device__ __forceinline__ accumulate(unsigned long long *__restrict acc, u
}

template <typename RealType, bool Negated>
void __global__ k_nonbonded_pairs(
void __global__ k_nonbonded_pair_list(
const int M, // number of pairs
const double *__restrict__ coords,
const double *__restrict__ params,
Expand Down
8 changes: 4 additions & 4 deletions timemachine/cpp/src/nonbonded.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#pragma once

#include "nonbonded_dense.hpp"
#include "nonbonded_pairs.hpp"
#include "nonbonded_all_pairs.hpp"
#include "nonbonded_pair_list.hpp"
#include "potential.hpp"
#include <vector>

Expand All @@ -10,10 +10,10 @@ namespace timemachine {
template <typename RealType, bool Interpolated> class Nonbonded : public Potential {

private:
NonbondedDense<RealType, Interpolated> dense_;
NonbondedAllPairs<RealType, Interpolated> dense_;

static const bool Negated = true;
NonbondedPairs<RealType, Negated, Interpolated> exclusions_; // implement exclusions as negated NonbondedPairs
NonbondedPairList<RealType, Negated, Interpolated> exclusions_; // implement exclusions as negated NonbondedPairList

public:
// these are marked public but really only intended for testing.
Expand Down
Loading