Skip to content

Commit

Permalink
Add offset to row atom indices in case NR != 0
Browse files Browse the repository at this point in the history
In the interaction group case (NR != 0), the row atoms are stored after
the NC column atoms in the sorted arrays (e.g. d_sorted_x), so the row
atom indices must be offset by the number of column atoms (NC).
  • Loading branch information
mcwitt committed Feb 15, 2022
1 parent b532e89 commit 0f24785
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 19 deletions.
49 changes: 33 additions & 16 deletions timemachine/cpp/src/kernels/k_nonbonded.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,8 @@ template <
bool COMPUTE_DU_DP>
// void __device__ __forceinline__ v_nonbonded_unified(
void __device__ v_nonbonded_unified(
const int N,
const int NC,
const int NR,
const double *__restrict__ coords,
const double *__restrict__ params, // [N]
const double *__restrict__ box,
Expand Down Expand Up @@ -297,6 +298,12 @@ void __device__ v_nonbonded_unified(
// int lambda_offset_i = atom_i_idx < N ? lambda_offset_idxs[atom_i_idx] : 0;
// int lambda_plane_i = atom_i_idx < N ? lambda_plane_idxs[atom_i_idx] : 0;

int N = NC + NR;

if (NR != 0) {
atom_i_idx += NC;
}

RealType ci_x = atom_i_idx < N ? coords[atom_i_idx * 3 + 0] : 0;
RealType ci_y = atom_i_idx < N ? coords[atom_i_idx * 3 + 1] : 0;
RealType ci_z = atom_i_idx < N ? coords[atom_i_idx * 3 + 2] : 0;
Expand Down Expand Up @@ -329,15 +336,15 @@ void __device__ v_nonbonded_unified(
// int lambda_offset_j = atom_j_idx < N ? lambda_offset_idxs[atom_j_idx] : 0;
// int lambda_plane_j = atom_j_idx < N ? lambda_plane_idxs[atom_j_idx] : 0;

RealType cj_x = atom_j_idx < N ? coords[atom_j_idx * 3 + 0] : 0;
RealType cj_y = atom_j_idx < N ? coords[atom_j_idx * 3 + 1] : 0;
RealType cj_z = atom_j_idx < N ? coords[atom_j_idx * 3 + 2] : 0;
RealType cj_w = atom_j_idx < N ? coords_w[atom_j_idx] : 0;
RealType cj_x = atom_j_idx < NC ? coords[atom_j_idx * 3 + 0] : 0;
RealType cj_y = atom_j_idx < NC ? coords[atom_j_idx * 3 + 1] : 0;
RealType cj_z = atom_j_idx < NC ? coords[atom_j_idx * 3 + 2] : 0;
RealType cj_w = atom_j_idx < NC ? coords_w[atom_j_idx] : 0;

RealType dq_dl_j = atom_j_idx < N ? dp_dl[atom_j_idx * 3 + 0] : 0;
RealType dsig_dl_j = atom_j_idx < N ? dp_dl[atom_j_idx * 3 + 1] : 0;
RealType deps_dl_j = atom_j_idx < N ? dp_dl[atom_j_idx * 3 + 2] : 0;
RealType dw_dl_j = atom_j_idx < N ? dw_dl[atom_j_idx] : 0;
RealType dq_dl_j = atom_j_idx < NC ? dp_dl[atom_j_idx * 3 + 0] : 0;
RealType dsig_dl_j = atom_j_idx < NC ? dp_dl[atom_j_idx * 3 + 1] : 0;
RealType deps_dl_j = atom_j_idx < NC ? dp_dl[atom_j_idx * 3 + 2] : 0;
RealType dw_dl_j = atom_j_idx < NC ? dw_dl[atom_j_idx] : 0;

unsigned long long gj_x = 0;
unsigned long long gj_y = 0;
Expand All @@ -347,9 +354,9 @@ void __device__ v_nonbonded_unified(
int lj_param_idx_sig_j = atom_j_idx * 3 + 1;
int lj_param_idx_eps_j = atom_j_idx * 3 + 2;

RealType qj = atom_j_idx < N ? params[charge_param_idx_j] : 0;
RealType sig_j = atom_j_idx < N ? params[lj_param_idx_sig_j] : 0;
RealType eps_j = atom_j_idx < N ? params[lj_param_idx_eps_j] : 0;
RealType qj = atom_j_idx < NC ? params[charge_param_idx_j] : 0;
RealType sig_j = atom_j_idx < NC ? params[lj_param_idx_sig_j] : 0;
RealType eps_j = atom_j_idx < NC ? params[lj_param_idx_eps_j] : 0;

unsigned long long g_qj = 0;
unsigned long long g_sigj = 0;
Expand Down Expand Up @@ -387,7 +394,8 @@ void __device__ v_nonbonded_unified(
// (ytz): note that d2ij must be *strictly* less than cutoff_squared. This is because we set the
// non-interacting atoms to exactly real_cutoff*real_cutoff. This ensures that atoms whose 4th dimension
// is set to cutoff are non-interacting.
if (d2ij < cutoff_squared && atom_j_idx > atom_i_idx && atom_j_idx < N && atom_i_idx < N) {
if (d2ij < cutoff_squared && atom_i_idx < N &&
(NR == 0 && atom_j_idx > atom_i_idx && atom_j_idx < N || NR != 0 && atom_j_idx < NC)) {

// electrostatics
RealType u;
Expand Down Expand Up @@ -528,7 +536,8 @@ void __device__ v_nonbonded_unified(

template <typename RealType, bool COMPUTE_U, bool COMPUTE_DU_DX, bool COMPUTE_DU_DL, bool COMPUTE_DU_DP>
void __global__ k_nonbonded_unified(
const int N,
const int NC,
const int NR,
const double *__restrict__ coords,
const double *__restrict__ params, // [N]
const double *__restrict__ box,
Expand All @@ -548,6 +557,12 @@ void __global__ k_nonbonded_unified(
int row_block_idx = ixn_tiles[tile_idx];
int atom_i_idx = row_block_idx * 32 + threadIdx.x;

int N = NC + NR;

if (NR != 0) {
atom_i_idx += NC;
}

RealType dq_dl_i = atom_i_idx < N ? dp_dl[atom_i_idx * 3 + 0] : 0;
RealType dsig_dl_i = atom_i_idx < N ? dp_dl[atom_i_idx * 3 + 1] : 0;
RealType deps_dl_i = atom_i_idx < N ? dp_dl[atom_i_idx * 3 + 2] : 0;
Expand All @@ -568,7 +583,8 @@ void __global__ k_nonbonded_unified(

if (tile_is_vanilla) {
v_nonbonded_unified<RealType, 0, COMPUTE_U, COMPUTE_DU_DX, COMPUTE_DU_DL, COMPUTE_DU_DP>(
N,
NC,
NR,
coords,
params,
box,
Expand All @@ -585,7 +601,8 @@ void __global__ k_nonbonded_unified(
u_buffer);
} else {
v_nonbonded_unified<RealType, 1, COMPUTE_U, COMPUTE_DU_DX, COMPUTE_DU_DL, COMPUTE_DU_DP>(
N,
NC,
NR,
coords,
params,
box,
Expand Down
1 change: 1 addition & 0 deletions timemachine/cpp/src/nonbonded_all_pairs.cu
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,7 @@ void NonbondedAllPairs<RealType, Interpolated>::execute_device(

kernel_ptrs_[kernel_idx]<<<p_ixn_count_[0], tpb, 0, stream>>>(
N,
0,
d_sorted_x_,
d_sorted_p_,
d_box,
Expand Down
3 changes: 2 additions & 1 deletion timemachine/cpp/src/nonbonded_all_pairs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
namespace timemachine {

typedef void (*k_nonbonded_fn)(
const int N,
const int NC,
const int NR,
const double *__restrict__ coords,
const double *__restrict__ params, // [N]
const double *__restrict__ box,
Expand Down
3 changes: 2 additions & 1 deletion timemachine/cpp/src/nonbonded_interaction_group.cu
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,8 @@ void NonbondedInteractionGroup<RealType, Interpolated>::execute_device(
if (p_ixn_count_[0] > 0) {

kernel_ptrs_[kernel_idx]<<<p_ixn_count_[0], tpb, 0, stream>>>(
N,
NC_,
NR_,
d_sorted_x_,
d_sorted_p_,
d_box,
Expand Down
3 changes: 2 additions & 1 deletion timemachine/cpp/src/nonbonded_interaction_group.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
namespace timemachine {

typedef void (*k_nonbonded_fn)(
const int N,
const int NC,
const int NR,
const double *__restrict__ coords,
const double *__restrict__ params, // [N]
const double *__restrict__ box,
Expand Down

0 comments on commit 0f24785

Please sign in to comment.