-
Notifications
You must be signed in to change notification settings - Fork 17
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fully variable nonbonded lambda interpolation #403
Conversation
return ref_potential(x, qlj, box, lamb) | ||
|
||
|
||
for lamb in [0.0, 0.2, 1.0]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@maxentile usage example and test here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This seems like a big step up -- can define protocols flexibly, and automatically get the resulting du_dl
s required for free energy calculations! This will allow us to run free energy calculations using arbitrary user-specified functions for how to vary distance offsets and nonbonded parameters throughout a protocol.
# E = 0 # DEBUG! | ||
qlj_src, ref_potential, test_potential = prepare_water_system( | ||
coords, | ||
lambda_plane_idxs, | ||
lambda_offset_idxs, | ||
p_scale=1.0, | ||
cutoff=cutoff | ||
) | ||
|
||
qlj_dst, _, _ = prepare_water_system( | ||
coords, | ||
lambda_plane_idxs, | ||
lambda_offset_idxs, | ||
p_scale=1.0, | ||
cutoff=cutoff | ||
) | ||
|
||
def transform_q(lamb): | ||
return lamb*lamb | ||
|
||
def transform_s(lamb): | ||
return jnp.sin(lamb*np.pi/2) | ||
|
||
def transform_e(lamb): | ||
return jnp.cos(lamb*np.pi/2) | ||
|
||
def transform_w(lamb): | ||
return (1-lamb*lamb) | ||
|
||
def interpolate_params(lamb, qlj_src, qlj_dst): | ||
new_q = (1-transform_q(lamb))*qlj_src[:, 0] + transform_q(lamb)*qlj_dst[:, 0] | ||
new_s = (1-transform_s(lamb))*qlj_src[:, 1] + transform_s(lamb)*qlj_dst[:, 1] | ||
new_e = (1-transform_e(lamb))*qlj_src[:, 2] + transform_e(lamb)*qlj_dst[:, 2] | ||
return jnp.stack([new_q, new_s, new_e], axis=1) | ||
|
||
def u_reference(x, params, box, lamb): | ||
d4 = cutoff*(lambda_plane_idxs + lambda_offset_idxs*transform_w(lamb)) | ||
d4 = jnp.expand_dims(d4, axis=-1) | ||
x = jnp.concatenate((x, d4), axis=1) | ||
|
||
qlj_src = params[:len(params)//2] | ||
qlj_dst = params[len(params)//2:] | ||
qlj = interpolate_params(lamb, qlj_src, qlj_dst) | ||
return ref_potential(x, qlj, box, lamb) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could move these definitions out of loop
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
willfix
args.append("lambda*lambda") # transform q | ||
args.append("sin(lambda*PI/2)") # transform sigma | ||
args.append("cos(lambda*PI/2)") # transform epsilon | ||
args.append("1-lambda*lambda") # transform w |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice! This seems highly flexible, and already covers the primary use case I can imagine. (After optimizing protocols represented parametrically as protocol(lam, weights) = dot(weights, basis_expand(lam))
, we can export to a string where the optimized weights appear as literals.
Where should we document the allowable syntax for these expressions? A few specific questions:
- Are there any other global constants that can be referenced in these expressions aside from
PI
? - Is it possible to define and reuse variables here? (Something like
0.1 * x + 0.2 * x*x + 0.3 * x*x*x ; x=(1 - lambda*lambda);
) - Is it possible to reference per-atom attributes in these expressions?
- Is there a practical limit on the length of these expressions?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are there any other global constants that can be referenced in these expressions aside from PI?
Right now no, but you can always just hard code in the constants. It woudln't hard to add in other constants though.
Is it possible to define and reuse variables here? (Something like 0.1 * x + 0.2 * xx + 0.3 * xxx ; x=(1 - lambdalambda);)
Yes absolutely, but right now the C++ code is written as return CUSTOM_EXPRESSION
; to support proper branching (eg. if statements beyond simple ternary operators). Currently it's similar to lambda expressions, but we can definitely relax this without difficulty.
Is it possible to reference per-atom attributes in these expressions?
Currently no, but if these are forcefield independent attributes, we may be able to support them without too much difficulty. We probably won't be able to do derivatives for the per-atom attributes though, since the forward-mode AD/CSD is only efficient for R^1->R^N
.
Is there a practical limit on the length of these expressions?
Anything that doesn't break the compiler :)
def transform_q(lamb): | ||
return lamb*lamb | ||
|
||
def transform_s(lamb): | ||
return jnp.sin(lamb*np.pi/2) | ||
|
||
def transform_e(lamb): | ||
return jnp.cos(lamb*np.pi/2) | ||
|
||
def transform_w(lamb): | ||
return (1-lamb*lamb) | ||
|
||
def interpolate_params(lamb, qlj_src, qlj_dst): | ||
new_q = (1-transform_q(lamb))*qlj_src[:, 0] + transform_q(lamb)*qlj_dst[:, 0] | ||
new_s = (1-transform_s(lamb))*qlj_src[:, 1] + transform_s(lamb)*qlj_dst[:, 1] | ||
new_e = (1-transform_e(lamb))*qlj_src[:, 2] + transform_e(lamb)*qlj_dst[:, 2] | ||
return jnp.stack([new_q, new_s, new_e], axis=1) | ||
|
||
def u_reference(x, params, box, lamb): | ||
d4 = cutoff*(lambda_plane_idxs + lambda_offset_idxs*transform_w(lamb)) | ||
d4 = jnp.expand_dims(d4, axis=-1) | ||
x = jnp.concatenate((x, d4), axis=1) | ||
|
||
qlj_src = params[:len(params)//2] | ||
qlj_dst = params[len(params)//2:] | ||
qlj = interpolate_params(lamb, qlj_src, qlj_dst) | ||
return ref_potential(x, qlj, box, lamb) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can be moved out of loop
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
will fix
@@ -78,6 +82,10 @@ class Nonbonded : public Potential { | |||
cudaStream_t stream | |||
); | |||
|
|||
jitify::KernelInstantiation compute_w_coords_instance_; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice! I wonder if there are other compelling opportunities to use jitify
library within timemachine
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yep - can of worms has been opened.
N = conf.shape[0] | ||
|
||
conf = convert_to_4d(conf, lamb, lambda_plane_idxs, lambda_offset_idxs, cutoff) | ||
if conf.shape[-1] == 3: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This allows nonbonded_v3
to accept conf
s of shape either (N, 3) or (N, 4)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Correct! I had to support being the use case where the 4D coordinates were generated outside of this function in the interpolation tests.
timemachine/cpp/src/gpu_utils.cuh
Outdated
<< nvrtcGetErrorString(result) << '\n'; \ | ||
exit(1); \ | ||
} \ | ||
} while(0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this do while required?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's not - but this macro isn't used, so I will remove it.
lambda, | ||
N, | ||
d_perm_, | ||
d_p, | ||
d_sorted_p_, | ||
d_sorted_dp_dl_ | ||
); | ||
gpuErrchk(cudaPeekAtLastError()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is called on ln 398 after a JIT call, do we still need this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
willfix
Clean-up WIP Fix du_dps Clean-up Improve default kwargs Allow lambda for w to be interpolated Add test for lambda_w interpolation AVoid hard coded source paths WIP Update paths WIP WIP Update src path WIP Trigger tests
f17b3c7
to
e228bc3
Compare
This PR implements a JIT-based compiler using regex substitution similar to OpenMM for allowing the end-user to specify arbitrary transformations of lambda on the charges, sigmas, epsilons, and w component. In particular, this is accomplished by enabling four extra strings on the constructor of the Nonbonded class, specifying the transformations that should be applied to the lambda parameter.
It roughly follows the design in #402, namely that parameters are interpolated via some arbitrary transformation of the form:
The 4D distance/softcore similary interpolates between 0 and 1*cutoff.
The neatest part of this PR is that the derivatives for
f(lambda)
is implemented using the complex step derivative. Which is numerically accurate to machine precision. The set of operators we allow is slightly less than that of lepton from OpenMM, but this has the major advantage of not requiring any symbolic analysis, as the complex step AD is nearly identical to that of the forward mode AD and is very efficient.Mechanically, this PR also separates out the computation of
w
coordinates during decoupling to a separate set of kernels.