Skip to content

Commit

Permalink
Add NonbondedInteractionGroup potential, rename existing nonbonded po…
Browse files Browse the repository at this point in the history
…tentials (#578)
  • Loading branch information
mcwitt authored Feb 18, 2022
1 parent 515b49e commit d378e88
Show file tree
Hide file tree
Showing 16 changed files with 1,369 additions and 138 deletions.
410 changes: 410 additions & 0 deletions tests/test_nonbonded_interaction_group.py

Large diffs are not rendered by default.

13 changes: 2 additions & 11 deletions tests/test_parameter_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,13 @@

config.update("jax_enable_x64", True)
import copy
import functools

import jax.numpy as jnp
import numpy as np
from common import GradientTest, prepare_water_system

from timemachine.lib import potentials


def interpolated_potential(conf, params, box, lamb, u_fn):
    """Evaluate ``u_fn`` with parameters linearly interpolated in ``lamb``.

    ``params`` holds two parameter sets stacked along axis 0: the first half
    is used at ``lamb == 0`` and the second half at ``lamb == 1``.

    Parameters
    ----------
    conf, box
        Passed through to ``u_fn`` unchanged.
    params : array
        Stacked end-state parameters; ``params.shape[0]`` must be even.
    lamb : float
        Interpolation weight applied to the second half of ``params``.
    u_fn : callable
        Called as ``u_fn(conf, new_params, box, lamb)``.

    Returns
    -------
    Whatever ``u_fn`` returns for the interpolated parameters.
    """
    # The split below is along axis 0, so that is the dimension that must be
    # even. Checking params.size instead would let e.g. shape (3, 2) through,
    # and the unequal halves would then silently broadcast against each other.
    assert params.shape[0] % 2 == 0

    CP = params.shape[0] // 2
    new_params = (1 - lamb) * params[:CP] + lamb * params[CP:]

    return u_fn(conf, new_params, box, lamb)
from timemachine.potentials import nonbonded


class TestInterpolatedPotential(GradientTest):
Expand Down Expand Up @@ -57,7 +48,7 @@ def test_nonbonded(self):

print("lambda", lamb, "cutoff", cutoff, "precision", precision, "xshape", coords.shape)

ref_interpolated_potential = functools.partial(interpolated_potential, u_fn=ref_potential)
ref_interpolated_potential = nonbonded.interpolated(ref_potential)

test_interpolated_potential = potentials.NonbondedInterpolated(*test_potential.args)

Expand Down
6 changes: 4 additions & 2 deletions timemachine/cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,9 @@ pybind11_add_module(${LIBRARY_NAME} SHARED NO_EXTRAS
src/gpu_utils.cu
src/vendored/hilbert.cpp
src/nonbonded.cu
src/nonbonded_dense.cu
src/nonbonded_pairs.cu
src/nonbonded_all_pairs.cu
src/nonbonded_pair_list.cu
src/nonbonded_interaction_group.cu
src/neighborlist.cu
src/harmonic_bond.cu
src/harmonic_angle.cu
Expand All @@ -54,6 +55,7 @@ pybind11_add_module(${LIBRARY_NAME} SHARED NO_EXTRAS
src/rmsd_align.cpp
src/summed_potential.cu
src/device_buffer.cu
src/kernels/k_nonbonded.cu
src/kernels/nonbonded_common.cu
)

Expand Down
91 changes: 91 additions & 0 deletions timemachine/cpp/src/kernels/k_nonbonded.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
#include "k_nonbonded.cuh"

// Shared by k_coords_to_kv and k_coords_to_kv_gather: returns the sort key of
// the atom at index atom_idx.
//
// The atom's coordinates are wrapped along each axis using the diagonal of
// box, binned on a grid of width max(bx, by, bz) / 255, and the resulting
// (bin_x, bin_y, bin_z) triple is looked up in bin_to_idx.
static __device__ __forceinline__ unsigned int coords_to_bin_key(
    const int atom_idx, const double *coords, const double *box, const unsigned int *bin_to_idx) {

    // these coords have to be centered
    // only the diagonal of the flattened 3x3 box is read
    const double bx = box[0 * 3 + 0];
    const double by = box[1 * 3 + 1];
    const double bz = box[2 * 3 + 2];

    // a single bin width for all three axes, derived from the longest box edge
    const double binWidth = max(max(bx, by), bz) / 255.0;

    double x = coords[atom_idx * 3 + 0];
    double y = coords[atom_idx * 3 + 1];
    double z = coords[atom_idx * 3 + 2];

    // wrap each coordinate by a whole number of box lengths along its axis
    x -= bx * floor(x / bx);
    y -= by * floor(y / by);
    z -= bz * floor(z / bz);

    // truncating double -> unsigned conversion selects the bin
    const unsigned int bin_x = x / binWidth;
    const unsigned int bin_y = y / binWidth;
    const unsigned int bin_z = z / binWidth;

    return bin_to_idx[bin_x * 256 * 256 + bin_y * 256 + bin_z];
}

// Generates key/value pairs from coordinates, to be radix sorted.
//
// Launch with a 1D grid covering at least N threads; thread i handles atom i
// and threads with i >= N exit immediately.
//
//   keys[i] = bin_to_idx entry of the bin containing atom i
//   vals[i] = i
void __global__ k_coords_to_kv(
    const int N,
    const double *coords,
    const double *box,
    const unsigned int *bin_to_idx,
    unsigned int *keys,
    unsigned int *vals) {

    const int atom_idx = blockIdx.x * blockDim.x + threadIdx.x;

    if (atom_idx >= N) {
        return;
    }

    keys[atom_idx] = coords_to_bin_key(atom_idx, coords, box, bin_to_idx);
    // uncomment below if you want to preserve the atom ordering
    // keys[atom_idx] = atom_idx;
    vals[atom_idx] = atom_idx;
}

// Variant of k_coords_to_kv that only processes the atoms listed in atom_idxs.
//
// Launch with a 1D grid covering at least N threads; thread i handles the
// atom atom_idxs[i] and threads with i >= N exit immediately.
//
//   keys[i] = bin_to_idx entry of the bin containing atom atom_idxs[i]
//   vals[i] = atom_idxs[i]
void __global__ k_coords_to_kv_gather(
    const int N,
    const unsigned int *atom_idxs,
    const double *coords,
    const double *box,
    const unsigned int *bin_to_idx,
    unsigned int *keys,
    unsigned int *vals) {

    const int idx = blockIdx.x * blockDim.x + threadIdx.x;

    if (idx >= N) {
        return;
    }

    const int atom_idx = atom_idxs[idx];

    keys[idx] = coords_to_bin_key(atom_idx, coords, box, bin_to_idx);
    // uncomment below if you want to preserve the atom ordering
    // keys[idx] = atom_idx;
    vals[idx] = atom_idx;
}

// Writes arr[i] = i for every i in [0, N).
// Expects a 1D launch covering at least N threads; surplus threads do nothing.
void __global__ k_arange(int N, unsigned int *arr) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < N) {
        arr[i] = i;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,45 +5,26 @@
#include "nonbonded_common.cuh"
#include "surreal.cuh"

// fills arr with its own indices: arr[i] = i for 0 <= i < N
void __global__ k_arange(int N, unsigned int *arr);

// generate kv values from coordinates to be radix sorted
void __global__ k_coords_to_kv(
const int N,
const double *coords,
const double *box,
const unsigned int *bin_to_idx,
unsigned int *keys,
unsigned int *vals) {

const int atom_idx = blockIdx.x * blockDim.x + threadIdx.x;

if (atom_idx >= N) {
return;
}

// these coords have to be centered
double bx = box[0 * 3 + 0];
double by = box[1 * 3 + 1];
double bz = box[2 * 3 + 2];

double binWidth = max(max(bx, by), bz) / 255.0;
unsigned int *vals);

double x = coords[atom_idx * 3 + 0];
double y = coords[atom_idx * 3 + 1];
double z = coords[atom_idx * 3 + 2];

x -= bx * floor(x / bx);
y -= by * floor(y / by);
z -= bz * floor(z / bz);

unsigned int bin_x = x / binWidth;
unsigned int bin_y = y / binWidth;
unsigned int bin_z = z / binWidth;

keys[atom_idx] = bin_to_idx[bin_x * 256 * 256 + bin_y * 256 + bin_z];
// uncomment below if you want to preserve the atom ordering
// keys[atom_idx] = atom_idx;
vals[atom_idx] = atom_idx;
}
// variant of k_coords_to_kv allowing the selection of a subset of coordinates
// thread i (for i < N) processes atom atom_idxs[i] and writes keys[i] and vals[i]
void __global__ k_coords_to_kv_gather(
const int N, // number of atoms in selection
const unsigned int *atom_idxs, // [N] indices of atoms to select
const double *coords, // read as coords[atom_idx * 3 + d], d in {0, 1, 2}; presumably covers all atoms, not just the selection (confirm with caller)
const double *box, // indexed as a flattened 3x3 matrix; only the diagonal entries are read
const unsigned int *bin_to_idx, // lookup table indexed by bin_x * 256 * 256 + bin_y * 256 + bin_z
unsigned int *keys, // [N] output: bin_to_idx entry for the bin containing atom atom_idxs[i]
unsigned int *vals); // [N] output: vals[i] = atom_idxs[i]

template <typename RealType>
void __global__ k_check_rebuild_box(const int N, const double *new_box, const double *old_box, int *rebuild) {
Expand Down Expand Up @@ -282,7 +263,8 @@ template <
bool COMPUTE_DU_DP>
// void __device__ __forceinline__ v_nonbonded_unified(
void __device__ v_nonbonded_unified(
const int N,
const int NC,
const int NR,
const double *__restrict__ coords,
const double *__restrict__ params, // [N]
const double *__restrict__ box,
Expand Down Expand Up @@ -316,6 +298,12 @@ void __device__ v_nonbonded_unified(
// int lambda_offset_i = atom_i_idx < N ? lambda_offset_idxs[atom_i_idx] : 0;
// int lambda_plane_i = atom_i_idx < N ? lambda_plane_idxs[atom_i_idx] : 0;

const int N = NC + NR;

if (NR != 0) {
atom_i_idx += NC;
}

RealType ci_x = atom_i_idx < N ? coords[atom_i_idx * 3 + 0] : 0;
RealType ci_y = atom_i_idx < N ? coords[atom_i_idx * 3 + 1] : 0;
RealType ci_z = atom_i_idx < N ? coords[atom_i_idx * 3 + 2] : 0;
Expand Down Expand Up @@ -348,15 +336,15 @@ void __device__ v_nonbonded_unified(
// int lambda_offset_j = atom_j_idx < N ? lambda_offset_idxs[atom_j_idx] : 0;
// int lambda_plane_j = atom_j_idx < N ? lambda_plane_idxs[atom_j_idx] : 0;

RealType cj_x = atom_j_idx < N ? coords[atom_j_idx * 3 + 0] : 0;
RealType cj_y = atom_j_idx < N ? coords[atom_j_idx * 3 + 1] : 0;
RealType cj_z = atom_j_idx < N ? coords[atom_j_idx * 3 + 2] : 0;
RealType cj_w = atom_j_idx < N ? coords_w[atom_j_idx] : 0;
RealType cj_x = atom_j_idx < NC ? coords[atom_j_idx * 3 + 0] : 0;
RealType cj_y = atom_j_idx < NC ? coords[atom_j_idx * 3 + 1] : 0;
RealType cj_z = atom_j_idx < NC ? coords[atom_j_idx * 3 + 2] : 0;
RealType cj_w = atom_j_idx < NC ? coords_w[atom_j_idx] : 0;

RealType dq_dl_j = atom_j_idx < N ? dp_dl[atom_j_idx * 3 + 0] : 0;
RealType dsig_dl_j = atom_j_idx < N ? dp_dl[atom_j_idx * 3 + 1] : 0;
RealType deps_dl_j = atom_j_idx < N ? dp_dl[atom_j_idx * 3 + 2] : 0;
RealType dw_dl_j = atom_j_idx < N ? dw_dl[atom_j_idx] : 0;
RealType dq_dl_j = atom_j_idx < NC ? dp_dl[atom_j_idx * 3 + 0] : 0;
RealType dsig_dl_j = atom_j_idx < NC ? dp_dl[atom_j_idx * 3 + 1] : 0;
RealType deps_dl_j = atom_j_idx < NC ? dp_dl[atom_j_idx * 3 + 2] : 0;
RealType dw_dl_j = atom_j_idx < NC ? dw_dl[atom_j_idx] : 0;

unsigned long long gj_x = 0;
unsigned long long gj_y = 0;
Expand All @@ -366,9 +354,9 @@ void __device__ v_nonbonded_unified(
int lj_param_idx_sig_j = atom_j_idx * 3 + 1;
int lj_param_idx_eps_j = atom_j_idx * 3 + 2;

RealType qj = atom_j_idx < N ? params[charge_param_idx_j] : 0;
RealType sig_j = atom_j_idx < N ? params[lj_param_idx_sig_j] : 0;
RealType eps_j = atom_j_idx < N ? params[lj_param_idx_eps_j] : 0;
RealType qj = atom_j_idx < NC ? params[charge_param_idx_j] : 0;
RealType sig_j = atom_j_idx < NC ? params[lj_param_idx_sig_j] : 0;
RealType eps_j = atom_j_idx < NC ? params[lj_param_idx_eps_j] : 0;

unsigned long long g_qj = 0;
unsigned long long g_sigj = 0;
Expand Down Expand Up @@ -403,11 +391,17 @@ void __device__ v_nonbonded_unified(
d2ij += delta_w * delta_w;
}

const bool valid_ij =
atom_i_idx < N &&
((NR == 0) ? atom_i_idx < atom_j_idx && atom_j_idx < N // all-pairs case, only compute the upper tri
// 0 <= i < N, i < j < N
: atom_j_idx < NC); // ixn groups case, compute all pairwise ixns
// NC <= i < N, 0 <= j < NC

// (ytz): note that d2ij must be *strictly* less than cutoff_squared. This is because we set the
// non-interacting atoms to exactly real_cutoff*real_cutoff. This ensures that atoms whose 4th dimension
// is set to cutoff are non-interacting.
if (d2ij < cutoff_squared && atom_j_idx > atom_i_idx && atom_j_idx < N && atom_i_idx < N) {

if (d2ij < cutoff_squared && valid_ij) {
// electrostatics
RealType u;
RealType es_prefactor;
Expand Down Expand Up @@ -547,7 +541,8 @@ void __device__ v_nonbonded_unified(

template <typename RealType, bool COMPUTE_U, bool COMPUTE_DU_DX, bool COMPUTE_DU_DL, bool COMPUTE_DU_DP>
void __global__ k_nonbonded_unified(
const int N,
const int NC,
const int NR,
const double *__restrict__ coords,
const double *__restrict__ params, // [N]
const double *__restrict__ box,
Expand All @@ -567,6 +562,12 @@ void __global__ k_nonbonded_unified(
int row_block_idx = ixn_tiles[tile_idx];
int atom_i_idx = row_block_idx * 32 + threadIdx.x;

const int N = NC + NR;

if (NR != 0) {
atom_i_idx += NC;
}

RealType dq_dl_i = atom_i_idx < N ? dp_dl[atom_i_idx * 3 + 0] : 0;
RealType dsig_dl_i = atom_i_idx < N ? dp_dl[atom_i_idx * 3 + 1] : 0;
RealType deps_dl_i = atom_i_idx < N ? dp_dl[atom_i_idx * 3 + 2] : 0;
Expand All @@ -587,7 +588,8 @@ void __global__ k_nonbonded_unified(

if (tile_is_vanilla) {
v_nonbonded_unified<RealType, 0, COMPUTE_U, COMPUTE_DU_DX, COMPUTE_DU_DL, COMPUTE_DU_DP>(
N,
NC,
NR,
coords,
params,
box,
Expand All @@ -604,7 +606,8 @@ void __global__ k_nonbonded_unified(
u_buffer);
} else {
v_nonbonded_unified<RealType, 1, COMPUTE_U, COMPUTE_DU_DX, COMPUTE_DU_DL, COMPUTE_DU_DP>(
N,
NC,
NR,
coords,
params,
box,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ void __device__ __forceinline__ accumulate(unsigned long long *__restrict acc, u
}

template <typename RealType, bool Negated>
void __global__ k_nonbonded_pairs(
void __global__ k_nonbonded_pair_list(
const int M, // number of pairs
const double *__restrict__ coords,
const double *__restrict__ params,
Expand Down
8 changes: 4 additions & 4 deletions timemachine/cpp/src/nonbonded.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#pragma once

#include "nonbonded_dense.hpp"
#include "nonbonded_pairs.hpp"
#include "nonbonded_all_pairs.hpp"
#include "nonbonded_pair_list.hpp"
#include "potential.hpp"
#include <vector>

Expand All @@ -10,10 +10,10 @@ namespace timemachine {
template <typename RealType, bool Interpolated> class Nonbonded : public Potential {

private:
NonbondedDense<RealType, Interpolated> dense_;
NonbondedAllPairs<RealType, Interpolated> dense_;

static const bool Negated = true;
NonbondedPairs<RealType, Negated, Interpolated> exclusions_; // implement exclusions as negated NonbondedPairs
NonbondedPairList<RealType, Negated, Interpolated> exclusions_; // implement exclusions as negated NonbondedPairList

public:
// these are marked public but really only intended for testing.
Expand Down
Loading

0 comments on commit d378e88

Please sign in to comment.