Fully variable nonbonded lambda interpolation #403

Merged: proteneer merged 2 commits into master from lambda_improvements_2 on May 3, 2021

@proteneer proteneer commented Apr 27, 2021

This PR implements a JIT-based compiler using regex substitution similar to OpenMM for allowing the end-user to specify arbitrary transformations of lambda on the charges, sigmas, epsilons, and w component. In particular, this is accomplished by enabling four extra strings on the constructor of the Nonbonded class, specifying the transformations that should be applied to the lambda parameter.
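The substitution step described above can be sketched roughly as follows. This is not the PR's actual code: the template text, the placeholder name `CUSTOM_EXPRESSION_CHARGE`, and the helper `render_kernel_source` are all hypothetical, chosen only to show how a user string could be spliced into kernel source before it is handed to a JIT compiler.

```python
import re

# Hypothetical device-function template; the real kernel source in the PR may differ.
TEMPLATE = """
template <typename NumericType>
NumericType __device__ transform_lambda_charge(NumericType lambda) {
    return CUSTOM_EXPRESSION_CHARGE;
}
"""


def render_kernel_source(template, expressions):
    """Replace each placeholder token with the user-supplied expression string."""
    source = template
    for placeholder, expr in expressions.items():
        pattern = r"\b" + re.escape(placeholder) + r"\b"
        # A callable replacement keeps backslashes in expr from being
        # interpreted as regex escape sequences.
        source = re.sub(pattern, lambda match, e=expr: e, source)
    return source


rendered = render_kernel_source(
    TEMPLATE, {"CUSTOM_EXPRESSION_CHARGE": "lambda*lambda"}
)
print(rendered)
```

Because the expression is pasted into source text verbatim, any syntax error in the user string would surface only when the generated source is compiled.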

It roughly follows the design in #402, namely that parameters are interpolated via some arbitrary transformation of the form:

p(lambda) = (1-f(lambda))*p_src + f(lambda)*p_dst, s.t. f(0)=0 and f(1)=1

The 4D distance/softcore similarly interpolates between 0 and 1*cutoff.
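As a minimal illustration of the interpolation rule, here is a pure-Python sketch. The function names and the numeric values are made up for the example; only the formula itself comes from the description above.

```python
import math


def interpolate(p_src, p_dst, f, lamb):
    """p(lambda) = (1 - f(lambda)) * p_src + f(lambda) * p_dst."""
    weight = f(lamb)
    return (1.0 - weight) * p_src + weight * p_dst


def f_quadratic(lamb):
    return lamb * lamb


def f_sine(lamb):
    return math.sin(lamb * math.pi / 2)


# Illustrative charges only; both schedules satisfy f(0) = 0 and f(1) = 1,
# so the end states are recovered exactly.
q_src, q_dst = -0.8, 0.4
for f in (f_quadratic, f_sine):
    assert math.isclose(interpolate(q_src, q_dst, f, 0.0), q_src)
    assert math.isclose(interpolate(q_src, q_dst, f, 1.0), q_dst)

# The same rule applied to the 4D offset, which runs from 0 to 1 * cutoff.
cutoff = 1.2
w_mid = interpolate(0.0, cutoff, f_quadratic, 0.5)
```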

The neatest part of this PR is that the derivative of f(lambda) is computed using the complex-step derivative, which is numerically accurate to machine precision. The set of operators we allow is slightly smaller than that of Lepton from OpenMM, but this has the major advantage of not requiring any symbolic analysis: the complex-step AD is nearly identical to forward-mode AD and is very efficient.
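For readers unfamiliar with the technique, the complex-step derivative evaluates f at x + i*h and reads the derivative off the imaginary part, f'(x) ≈ Im(f(x + i*h)) / h. No subtraction of nearly equal numbers is involved, so h can be made tiny without cancellation error. The sketch below uses Python's `cmath` purely for illustration; the PR does this inside CUDA kernels, and the helper name here is an assumption.

```python
import cmath
import math


def complex_step_derivative(f, x, h=1e-20):
    """Approximate f'(x) as Im(f(x + i*h)) / h for a real-analytic f."""
    return f(complex(x, h)).imag / h


def transform_s(lamb):
    # Same schedule as the sigma transform in the test, written with cmath
    # so that it accepts a complex argument.
    return cmath.sin(lamb * math.pi / 2)


lamb = 0.3
csd = complex_step_derivative(transform_s, lamb)
exact = (math.pi / 2) * math.cos(lamb * math.pi / 2)
print(csd, exact)
```

The one requirement is that every operator in the expression has a complex-valued implementation, which is consistent with the restricted operator set mentioned above.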

Mechanically, this PR also separates out the computation of w coordinates during decoupling to a separate set of kernels.

return ref_potential(x, qlj, box, lamb)


for lamb in [0.0, 0.2, 1.0]:
Owner Author

@proteneer proteneer Apr 27, 2021

@maxentile usage example and test here

@proteneer proteneer marked this pull request as ready for review April 27, 2021 15:13
@proteneer proteneer requested review from maxentile and badisa April 27, 2021 15:13
Collaborator

@maxentile maxentile left a comment

This seems like a big step up -- can define protocols flexibly, and automatically get the resulting du_dls required for free energy calculations! This will allow us to run free energy calculations using arbitrary user-specified functions for how to vary distance offsets and nonbonded parameters throughout a protocol.

Comment on lines 116 to 159
# E = 0 # DEBUG!
qlj_src, ref_potential, test_potential = prepare_water_system(
    coords,
    lambda_plane_idxs,
    lambda_offset_idxs,
    p_scale=1.0,
    cutoff=cutoff
)

qlj_dst, _, _ = prepare_water_system(
    coords,
    lambda_plane_idxs,
    lambda_offset_idxs,
    p_scale=1.0,
    cutoff=cutoff
)

def transform_q(lamb):
    return lamb*lamb

def transform_s(lamb):
    return jnp.sin(lamb*np.pi/2)

def transform_e(lamb):
    return jnp.cos(lamb*np.pi/2)

def transform_w(lamb):
    return (1-lamb*lamb)

def interpolate_params(lamb, qlj_src, qlj_dst):
    new_q = (1-transform_q(lamb))*qlj_src[:, 0] + transform_q(lamb)*qlj_dst[:, 0]
    new_s = (1-transform_s(lamb))*qlj_src[:, 1] + transform_s(lamb)*qlj_dst[:, 1]
    new_e = (1-transform_e(lamb))*qlj_src[:, 2] + transform_e(lamb)*qlj_dst[:, 2]
    return jnp.stack([new_q, new_s, new_e], axis=1)

def u_reference(x, params, box, lamb):
    d4 = cutoff*(lambda_plane_idxs + lambda_offset_idxs*transform_w(lamb))
    d4 = jnp.expand_dims(d4, axis=-1)
    x = jnp.concatenate((x, d4), axis=1)

    qlj_src = params[:len(params)//2]
    qlj_dst = params[len(params)//2:]
    qlj = interpolate_params(lamb, qlj_src, qlj_dst)
    return ref_potential(x, qlj, box, lamb)
Collaborator

Could move these definitions out of loop

Owner Author

willfix

Comment on lines +169 to +172
args.append("lambda*lambda") # transform q
args.append("sin(lambda*PI/2)") # transform sigma
args.append("cos(lambda*PI/2)") # transform epsilon
args.append("1-lambda*lambda") # transform w
Collaborator

Nice! This seems highly flexible, and already covers the primary use case I can imagine. (After optimizing protocols represented parametrically as protocol(lam, weights) = dot(weights, basis_expand(lam)), we can export to a string where the optimized weights appear as literals.)

Where should we document the allowable syntax for these expressions? A few specific questions:

  • Are there any other global constants that can be referenced in these expressions aside from PI?
  • Is it possible to define and reuse variables here? (Something like 0.1 * x + 0.2 * x*x + 0.3 * x*x*x ; x=(1 - lambda*lambda);)
  • Is it possible to reference per-atom attributes in these expressions?
  • Is there a practical limit on the length of these expressions?
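The export step mentioned in this comment could look something like the sketch below. It assumes a monomial basis (lambda, lambda^2, lambda^3, ...), which is only one possible choice of `basis_expand`, and the helper name `export_polynomial_protocol` is hypothetical. It also assumes that the expression strings accept plain floating-point literals with `*` and `+`, which the examples in this PR suggest but do not document.

```python
import math


def export_polynomial_protocol(weights, variable="lambda"):
    """Render sum_k weights[k-1] * variable**k as a string with literal weights."""
    terms = []
    for k, w in enumerate(weights, start=1):
        power = "*".join([variable] * k)  # e.g. "lambda*lambda" for k = 2
        terms.append(f"{w!r}*{power}")
    return " + ".join(terms)


def protocol(lam, weights):
    """Reference evaluation: dot(weights, [lam, lam**2, lam**3, ...])."""
    return sum(w * lam ** k for k, w in enumerate(weights, start=1))


# Illustrative weights that sum to 1, so the schedule ends at f(1) = 1.
weights = [0.1, 0.2, 0.7]
expr = export_polynomial_protocol(weights)
print(expr)

# "lambda" is a Python keyword, so the round-trip check uses another name.
expr_py = export_polynomial_protocol(weights, variable="lam")
value = eval(expr_py, {"__builtins__": {}}, {"lam": 0.5})
```

Writing powers as repeated multiplication avoids depending on whether the expression syntax supports a power operator.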

Owner Author

Are there any other global constants that can be referenced in these expressions aside from PI?

Right now no, but you can always just hard-code the constants. It wouldn't be hard to add other constants though.

Is it possible to define and reuse variables here? (Something like 0.1 * x + 0.2 * x*x + 0.3 * x*x*x ; x=(1 - lambda*lambda);)

Yes, absolutely. Right now the C++ code is written as return CUSTOM_EXPRESSION;, so each string has to be a single expression, similar to a lambda expression. We can definitely relax this without difficulty to support proper branching (e.g. if statements beyond simple ternary operators).

Is it possible to reference per-atom attributes in these expressions?

Currently no, but if these are forcefield independent attributes, we may be able to support them without too much difficulty. We probably won't be able to do derivatives for the per-atom attributes though, since the forward-mode AD/CSD is only efficient for R^1->R^N.

Is there a practical limit on the length of these expressions?

Anything that doesn't break the compiler :)
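The remark about R^1->R^N can be illustrated with a small sketch: a single complex-step evaluation in the scalar lambda yields the derivative of every interpolated parameter at once, whereas differentiating with respect to N per-atom inputs would need one such pass per input. The parameter values and the function name below are invented for the example.

```python
import cmath
import math


def interpolate_all(lamb, p_src, p_dst):
    """Interpolate every parameter with the shared schedule f = sin(lambda*pi/2)."""
    f = cmath.sin(lamb * math.pi / 2)
    return [(1 - f) * a + f * b for a, b in zip(p_src, p_dst)]


# Illustrative per-atom parameters (N = 3).
p_src = [0.0, 0.3, 1.0]
p_dst = [1.0, 0.5, 0.2]

h = 1e-20
lamb = 0.25

# One evaluation at lambda + i*h carries values and derivatives together.
out = interpolate_all(complex(lamb, h), p_src, p_dst)
values = [z.real for z in out]
dp_dl = [z.imag / h for z in out]

# Analytic check: dp/dlambda = f'(lambda) * (p_dst - p_src).
f_prime = (math.pi / 2) * math.cos(lamb * math.pi / 2)
expected = [f_prime * (b - a) for a, b in zip(p_src, p_dst)]
```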

Comment on lines 133 to 159
def transform_q(lamb):
    return lamb*lamb

def transform_s(lamb):
    return jnp.sin(lamb*np.pi/2)

def transform_e(lamb):
    return jnp.cos(lamb*np.pi/2)

def transform_w(lamb):
    return (1-lamb*lamb)

def interpolate_params(lamb, qlj_src, qlj_dst):
    new_q = (1-transform_q(lamb))*qlj_src[:, 0] + transform_q(lamb)*qlj_dst[:, 0]
    new_s = (1-transform_s(lamb))*qlj_src[:, 1] + transform_s(lamb)*qlj_dst[:, 1]
    new_e = (1-transform_e(lamb))*qlj_src[:, 2] + transform_e(lamb)*qlj_dst[:, 2]
    return jnp.stack([new_q, new_s, new_e], axis=1)

def u_reference(x, params, box, lamb):
    d4 = cutoff*(lambda_plane_idxs + lambda_offset_idxs*transform_w(lamb))
    d4 = jnp.expand_dims(d4, axis=-1)
    x = jnp.concatenate((x, d4), axis=1)

    qlj_src = params[:len(params)//2]
    qlj_dst = params[len(params)//2:]
    qlj = interpolate_params(lamb, qlj_src, qlj_dst)
    return ref_potential(x, qlj, box, lamb)
Collaborator

Can be moved out of loop

Owner Author

will fix

@@ -78,6 +82,10 @@ class Nonbonded : public Potential {
cudaStream_t stream
);

jitify::KernelInstantiation compute_w_coords_instance_;
Collaborator

Nice! I wonder if there are other compelling opportunities to use jitify library within timemachine

Owner Author

Yep - can of worms has been opened.

N = conf.shape[0]

conf = convert_to_4d(conf, lamb, lambda_plane_idxs, lambda_offset_idxs, cutoff)
if conf.shape[-1] == 3:
Collaborator

This allows nonbonded_v3 to accept confs of shape either (N, 3) or (N, 4)?

Owner Author

Correct! I had to support the use case where the 4D coordinates were generated outside of this function in the interpolation tests.
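A rough sketch of that dispatch, using plain lists instead of arrays: the body of `convert_to_4d` is not shown in this conversation, so the formula below simply mirrors the `d4` line from the test with an identity transform on lambda, and the wrapper `ensure_4d` is a hypothetical name.

```python
def convert_to_4d(conf, lamb, lambda_plane_idxs, lambda_offset_idxs, cutoff):
    """Append w = cutoff * (plane + offset * lamb) to each 3D coordinate."""
    return [
        list(xyz) + [cutoff * (plane + offset * lamb)]
        for xyz, plane, offset in zip(conf, lambda_plane_idxs, lambda_offset_idxs)
    ]


def ensure_4d(conf, lamb, lambda_plane_idxs, lambda_offset_idxs, cutoff):
    """Lift (N, 3) coordinates to 4D; pass (N, 4) coordinates through unchanged."""
    width = len(conf[0])
    if width == 3:
        return convert_to_4d(conf, lamb, lambda_plane_idxs, lambda_offset_idxs, cutoff)
    if width == 4:
        # The caller already generated the w coordinates, e.g. with a custom transform.
        return [list(row) for row in conf]
    raise ValueError(f"expected 3 or 4 columns, got {width}")


conf3 = [[0.0, 0.0, 0.0], [1.0, 2.0, 3.0]]
conf4 = ensure_4d(conf3, 0.5, [0, 1], [1, 0], 2.0)
```

Passing already-4D coordinates through untouched is what lets a reference implementation apply its own transform to w before calling the potential.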

<< nvrtcGetErrorString(result) << '\n'; \
exit(1); \
} \
} while(0)
Collaborator

Is this do while required?

Owner Author

It's not - but this macro isn't used, so I will remove it.

lambda,
N,
d_perm_,
d_p,
d_sorted_p_,
d_sorted_dp_dl_
);
gpuErrchk(cudaPeekAtLastError());
Collaborator

This is called on line 398 after a JIT call. Do we still need this?

Owner Author

willfix

@proteneer proteneer mentioned this pull request May 3, 2021
Clean-up

WIP

Fix du_dps

Clean-up

Improve default kwargs

Allow lambda for w to be interpolated

Add test for lambda_w interpolation

Avoid hard coded source paths

WIP

Update paths

WIP

WIP

Update src path

WIP

Trigger tests
@proteneer proteneer force-pushed the lambda_improvements_2 branch from f17b3c7 to e228bc3 Compare May 3, 2021 18:27
@proteneer proteneer merged commit 0cda14c into master May 3, 2021
@proteneer proteneer deleted the lambda_improvements_2 branch November 8, 2022 01:19

3 participants