Feature/neighborlist on host ligand #545
@@ -303,12 +309,12 @@ void __global__ k_find_blocks_with_ixns(
atom_box_dy = max(static_cast<RealType>(0.0), fabs(atom_box_dy) - col_bb_ext_y);
atom_box_dz = max(static_cast<RealType>(0.0), fabs(atom_box_dz) - col_bb_ext_z);

bool check_column_atoms =
atom_i_idx < K &&
This additional check is important: without it you get junk interactions and the test fails. Previously we were getting lucky with positions beyond the number of atoms we have.
Good call!
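To illustrate why the bound matters, here is a minimal host-side C++ sketch. It is not the PR's kernel. It assumes atoms are stored in a buffer padded to a multiple of 32 slots, and that the padding happens to hold coordinates near the query point. `TILE`, `make_padded_coords`, and `count_within_cutoff` are hypothetical names used only for this illustration.

```cpp
#include <cmath>
#include <vector>

// Hypothetical tile width; the diff above works on 32-atom blocks.
constexpr int TILE = 32;

// Build a padded 1-D coordinate buffer: K real atoms placed far from the origin,
// with the trailing padding slots left at 0.0 (an assumed "junk" value).
std::vector<double> make_padded_coords(int K) {
    int padded = ((K + TILE - 1) / TILE) * TILE;
    std::vector<double> coords(padded, 0.0);
    for (int i = 0; i < K; i++) {
        coords[i] = 10.0 + i;
    }
    return coords;
}

// Count slots within `cutoff` of `query`. With check_bounds == false the padding
// slots are treated as atoms, which mimics the junk interactions described above.
int count_within_cutoff(const std::vector<double> &coords, int K, double query, double cutoff, bool check_bounds) {
    int count = 0;
    for (int i = 0; i < static_cast<int>(coords.size()); i++) {
        bool is_real_atom = !check_bounds || i < K;
        if (is_real_atom && std::fabs(coords[i] - query) < cutoff) {
            count++;
        }
    }
    return count;
}
```

With 40 real atoms the buffer has 64 slots, so the unchecked count picks up the 24 padding slots while the checked count ignores them.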
@@ -1,9 +1,11 @@
#pragma once

__device__ __host__ double sign(double a) { return (a > 0) - (a < 0); }
int __forceinline__ ceil_divide(int x, int y) { return (x + y - 1) / y; }
I would like this to be more widely used, but for now I just used it for the code that relates to these changes.
Would it be okay if we kept the old sign code as well? Not sure why it was deleted.
It's only used by the GBSA potential, which isn't currently compiled. But it makes sense to keep it if we don't also remove the GBSA potential.
Okay, these are really just jvp rules for autodiff. Let's keep them for now since they may be useful later on?
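For readers unfamiliar with the two helpers in this diff, here is a plain C++ sketch of both. The CUDA qualifiers (`__device__ __host__`, `__forceinline__`) are dropped so it compiles with an ordinary host compiler; the bodies match the lines shown in the diff.

```cpp
// Branch-free sign: (a > 0) and (a < 0) each convert to 0 or 1,
// so their difference is -1, 0, or +1.
inline double sign(double a) { return (a > 0) - (a < 0); }

// Integer ceiling division for non-negative x and positive y, e.g. the number
// of 32-wide blocks needed to cover x atoms.
inline int ceil_divide(int x, int y) { return (x + y - 1) / y; }
```

For example, `ceil_divide(33, 32)` gives 2 because 33 atoms spill into a second block of 32, while `ceil_divide(32, 32)` gives 1.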
Took a first pass and left some surface-level comments!
cudaDeviceSynchronize();
const int B = this->B(); //(N+32-1)/32;
const int tpb = 32;
Not specific to this PR, but the definition tpb = 32 seems replicated in a lot of places - would it make sense to declare it as a global (or at least file-level) constant?
We definitely should, and it's made worse by the fact that in some places tpb is used as a substitute for block size rather than warp size.
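One possible shape for that refactor, as a sketch only: keep the warp width and the launch block size as two separate file-level constants so neither stands in for the other. The names `WARP_SIZE`, `DEFAULT_THREADS_PER_BLOCK`, and `num_blocks` are hypothetical and do not come from the PR.

```cpp
// Hypothetical file-level constants replacing repeated `const int tpb = 32;` literals.
constexpr int WARP_SIZE = 32;                 // warp width on NVIDIA GPUs
constexpr int DEFAULT_THREADS_PER_BLOCK = 32; // launch block size, tunable independently

static_assert(DEFAULT_THREADS_PER_BLOCK % WARP_SIZE == 0,
              "block size should be a whole number of warps");

// Number of thread blocks needed to cover n work items.
constexpr int num_blocks(int n, int threads_per_block = DEFAULT_THREADS_PER_BLOCK) {
    return (n + threads_per_block - 1) / threads_per_block;
}
```

Separating the two makes it possible to raise the block size later (say, to 128) without silently changing code that relies on the warp width.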
timemachine/cpp/src/math_utils.cuh
Outdated
__device__ __host__ float sign(Surreal<float> a) { return (a.real > 0) - (a.real < 0); }
// __device__ __host__ double sign(Surreal<double> a) { return (a.real > 0) - (a.real < 0); }

// __device__ __host__ float sign(Surreal<float> a) { return (a.real > 0) - (a.real < 0); }
Was commenting these out intentional? Would it make sense to just remove them if unused?
timemachine/cpp/src/neighborlist.cu
Outdated
const int N, // Number of atoms in column
const int K, // Number of atoms in row
This naming feels a bit strange to me, since I'm used to N being the total number. Is this conventional, or could it make sense to give a different name to the number of atoms in a column? (It's possible this will make more sense to me once I have a better understanding of what's going on here.)
I find that N is just the size of 'something'; I can change the names more broadly.
I think I also got tripped by this while reading the PR, maybe we can rename N to NR and K to NC (num_row and num_col atoms)?
Looks great in my initial testing!
And it can yield a big speed-up in cases where we might want to compute a batch of host-ligand neighborlists on stored snapshots (30x speed-up on the test system, from ~70 snapshots / second to ~2000 snapshots / second: https://gist.github.com/maxentile/fbaf654a7f21c05adce3b04145c43866 )
timemachine/cpp/src/neighborlist.cu
Outdated
// assert(N == N_);
std::vector<std::vector<int>>
Neighborlist<RealType>::get_nblist_host(int N, const double *h_coords, const double *h_box, const double cutoff) {
// Call into the impl that takes colum and row coords
typo colum
timemachine/cpp/src/neighborlist.cu
Outdated
@@ -8,23 +8,38 @@

namespace timemachine {

template <typename RealType> Neighborlist<RealType>::Neighborlist(int N) : N_(N) {
template <typename RealType>
Neighborlist<RealType>::Neighborlist(const int N, const int K) : N_(N), K_(K), compute_full_matrix_(K > 0) {
Should we change compute_full_matrix_ into a function? i.e. if K_ == 0 compute_full_matrix returns True (as opposed to storing it as a separate variable that may get out of sync during a later refactor)
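A minimal sketch of that suggestion, using a simplified stand-in class rather than the PR's `Neighborlist`. Note that the comment above phrases the condition as `K_ == 0`, while the constructor in the diff initializes the flag with `K > 0`; the sketch follows the diff, and which polarity is intended is an assumption here. `NeighborlistSketch` and its accessors are hypothetical names.

```cpp
// Simplified stand-in for the PR's Neighborlist; only the size bookkeeping is shown.
class NeighborlistSketch {
public:
    NeighborlistSketch(int N, int K) : N_(N), K_(K) {}

    int N() const { return N_; }
    int K() const { return K_; }

    // Derived on demand from K_ (mirroring the `K > 0` initializer in the diff),
    // so there is no separate stored flag that could drift out of sync with K_.
    bool compute_full_matrix() const { return K_ > 0; }

private:
    int N_;
    int K_;
};
```

Because the flag is computed rather than stored, a later refactor that changes how `K_` is set cannot leave a stale value behind.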
This is probably also the bare minimum speed-up, as there just isn't enough work to saturate the GPU for something like this to proceed one ligand at a time. In fact, I think this code as-is can probably support the batch use case (without resorting to streams), where we aggregate row atoms so they represent multiple conformations of the same ligand, and we emit a set of tiles from this neighborlist that can compute ligand-host energies for multiple conformations all at once.
@badisa Can we add a test case where we use the neighborlist to compute the U_HH interactions? Let NR be the number of row atoms, NC be the number of column atoms, and NT the total number of atoms in the complete system. Set NR to some number < NT, and NC to zero. I think right now we only test the cases where NR < NT, NC > 0 (U_HL) and NR = NT, NC = 0 (U original), but it would be good to also test NR < NT, NC = 0 (U_HH).
@proteneer How is this different from the previous test that computes the ixns of a water box? You would just not include the ligand in the coordinates passed as the columns. https://github.com/proteneer/timemachine/blob/master/tests/test_nblist.py#L69
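For any of the row/column combinations discussed above, a brute-force per-atom reference is one way to check results. The sketch below is an assumption-laden illustration, not the project's test code: it works per atom rather than per 32-atom block, assumes an orthorhombic periodic box, and uses hypothetical names (`Vec3`, `min_image_dist2`, `reference_nblist`).

```cpp
#include <array>
#include <cmath>
#include <cstddef>
#include <vector>

using Vec3 = std::array<double, 3>;

// Squared distance between a and b under the minimum-image convention,
// for an orthorhombic box with side lengths `box`.
double min_image_dist2(const Vec3 &a, const Vec3 &b, const Vec3 &box) {
    double d2 = 0.0;
    for (int d = 0; d < 3; d++) {
        double delta = a[d] - b[d];
        delta -= box[d] * std::nearbyint(delta / box[d]);
        d2 += delta * delta;
    }
    return d2;
}

// For each row atom, list the indices of column atoms strictly within `cutoff`.
std::vector<std::vector<int>> reference_nblist(
    const std::vector<Vec3> &rows, const std::vector<Vec3> &cols, const Vec3 &box, double cutoff) {
    std::vector<std::vector<int>> ixns(rows.size());
    for (std::size_t i = 0; i < rows.size(); i++) {
        for (std::size_t j = 0; j < cols.size(); j++) {
            if (min_image_dist2(rows[i], cols[j], box) < cutoff * cutoff) {
                ixns[i].push_back(static_cast<int>(j));
            }
        }
    }
    return ixns;
}
```

Passing the host atoms as both `rows` and `cols` would cover a host-host style check, while passing ligand atoms as one set and host atoms as the other covers the host-ligand case.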
Generally LGTM, save for some minor comments. I would like to see the test mentioned on using it for U_HH, but it probably isn't a blocker for merge.
}
if (D != 3) {
throw std::runtime_error("D != 3");
}
Can we add a consistency check that if K==0 then col_ptr is null, and if K>0, then col_ptr is not null?
Hm, never mind, I think I see how you'd probably use this in a nonbonded.cu later on! So we'd simply rely on the caller to prepare the sizes of the input arrays directly.
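For reference, the consistency check proposed above could be sketched as follows, even though the follow-up leaves this responsibility to the caller. The helper name `validate_optional_coords` is hypothetical; `K` and `col_ptr` follow the names used in the comment, and the error style mirrors the `D != 3` check in the diff.

```cpp
#include <stdexcept>

// Hypothetical helper: the second coordinate set must be absent when K == 0
// and present when K > 0.
void validate_optional_coords(int K, const double *col_ptr) {
    if (K < 0) {
        throw std::runtime_error("K < 0");
    }
    if (K == 0 && col_ptr != nullptr) {
        throw std::runtime_error("K == 0 but col_ptr is not null");
    }
    if (K > 0 && col_ptr == nullptr) {
        throw std::runtime_error("K > 0 but col_ptr is null");
    }
}
```

Failing early with a descriptive message is cheaper to debug than a kernel reading through a null or mismatched pointer.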
Never mind - test for U_HH isn't needed. This PR is great!
Agreed! The speed-up from later applying this batch optimization should be substantial. Just to have a baseline, here are my timings for the Jax reference implementation (computing all host-ligand distances in double precision on the CPU using
* Functions aren't called
* Important to compute nonbonded energies of host-ligand
* Thanks to @jfass for noticing that we could get rid of the python loop
f786505 to cd815ab
This PR extends our Neighborlist to support construction from two sets of coordinates. This will be necessary for evaluating host-ligand interactions.
The performance is unchanged from what is currently on master.
Current Benchmarks on A4000
The changes in the PR