-
Notifications
You must be signed in to change notification settings - Fork 17
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fast batched nonbonded on interaction groups, as fxn of ligand charges and epsilons #685
Conversation
Utility function to find atom pairs that depend on ligand and contribute to nonbonded sum for each conf in a batch of confs...
Possibly save some debugging time in the future...
Completes 852fe29
* compute_batch_nonbonded_exceptions * get_intramolecular_nb_pairs_and_prefactors * construct_batch_nonbonded_hl_as_fxn_of_ligand_nb_params * construct_batch_nonbonded_ll_as_fxn_of_ligand_nb_params
Applying suggestion from #685 (comment)
Addressing #685 (comment)
Applying suggestion in #685 (comment)
tests/common.py
Outdated
for (i, j), exc in zip(exclusion_idxs, scales[:, 1]): | ||
lj_rescale_mask[i][j] = 1 - exc | ||
lj_rescale_mask[j][i] = 1 - exc | ||
charge_rescale_mask, lj_rescale_mask = nonbonded.convert_exceptions_to_rescale_masks(exclusion_idxs, scales, N) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Point of clarification, is an exception an exclusion?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ahh, this should be "exclusions" -- will rename
(I thought "exceptions" are pairs that need to be handled differently, "exclusions" are pairs that are omitted entirely (special case of exceptions), but this is not actually the convention.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Renamed in c4077ae
timemachine/fe/topology.py
Outdated
@@ -546,8 +547,12 @@ def parameterize_nonbonded(self, ff_q_params, ff_lj_params): | |||
qlj_params, nb_potential = super().parameterize_nonbonded(ff_q_params, ff_lj_params) | |||
|
|||
# halve the strength of the charge and the epsilon parameters | |||
src_qlj_params = jnp.asarray(qlj_params).at[:, 0].multiply(0.5) | |||
src_qlj_params = jnp.asarray(src_qlj_params).at[:, 2].multiply(0.5) | |||
charge_indices = jax.numpy.index_exp[:, 0] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a reason jax.numpy was used here instead of jnp?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nope -- good catch!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
~Also is there a reason to use the index_exp
syntax over the .at
syntax? ~
edit: Is there a reason to use index_exp
with .at
rather than do the indexing directly? Just to reuse indices and make it more explicit?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just to be more explicit -- should have identical behavior
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Partially addressed in 4becd4a
timemachine/potentials/jax_utils.py
Outdated
@@ -1,34 +1,30 @@ | |||
from typing import Tuple | |||
|
|||
import jax | |||
import jax.numpy as np | |||
import numpy as onp | |||
from jax import vmap | |||
|
|||
Array = onp.array |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: Since you are changing the usage, might be good to switch this to onp.typing.NDArray
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated in 78ac174
Addressing #685 (comment)
Addressing #685 (comment)
traj: [T, N, D] array | ||
boxes: diagonal [T, D, D] array (or [T] array of Nones) | ||
charges: [N] array | ||
ligand_indices: [N_lig] array of ints |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are there any limitations we want to place on ligand_indices and env_indices? Should they be disjoint?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
They should be disjoint (throw error if not), and ligand_indices should be smaller than env_indices (print warning if not) -- will fix
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Partially addressed in 762fc61 (and applied to nonbonded_v3_interaction_groups
as well)
Addressing #685 (comment)
* Partially address #685 (comment) * Also validate input for nonbonded_v3_interaction_groups
tests/test_jax_nonbonded.py
Outdated
reweight_ref = jit(make_reweighter(u_batch_ref)) | ||
reweight_test = jit(make_reweighter(u_batch_test)) | ||
|
||
for _ in range(50): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This test takes about 50 seconds on CI, does this need 50 iterations to be correct?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ahh, on Friday it was ~6 seconds -- will track down which change made it more expensive and reduce iterations if needed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Locally ~80-90% of the iteration time is in computing reference gradients
timemachine/tests/test_jax_nonbonded.py
Line 463 in 762fc61
v_ref, g_ref = value_and_grad(reweight_ref, argnums=argnum)(eps_ligand, q_ligand) |
Will rearrange this slightly and reduce the number of iterations from 50 to 5...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Addressed in 930ae0d
(in initial version of test the iterations were ~ 0.03 seconds each, but in current version the iterations are ~0.3 seconds each, so it's good to reduce the number of iterations!)
Thanks for the reviews! Enabling auto-merge |
Improved Jax utilities for processing batches of snapshots, with the aim of computing a batch of nonbonded energies as an efficient differentiable function of ligand nonbonded parameters.
Main feature:
Utilities for precomputing small summary vectors of a stored trajectory, so that ligand-environment nonbonded interaction group energies as a function of ligand charges (or ligand epsilons) can be computed as dot products with the small summary vectors. This is directly inspired by the linear basis function approach developed by Levi Naden and Michael Shirts in https://pubmed.ncbi.nlm.nih.gov/26580188/ .
(Note: These utilities make reweighting virtually a no-op, cost-wise... Rough timings are ~5 microseconds to compute a reweighted hydration free energy in explicit solvent (using Jax in double-precision CPU mode)...)
Additional refactoring:
pairs
([N, 2] array) instead of a pair ofinds_l
,inds_r
(2 [N,] arrays)nonbonded.convert_exceptions_to_rescale_masks
nonbonded.convert_exclusions_to_rescale_masks
rescale_coordinates
to atticjax.ops.index_update(x, idx, y)
->x.at[idx].set(y)
,x.at[idx].multiply(0.5)
params.at[:,0].multiply(0.5)
->params.at[charge_indices].multiply(0.5)
)np.indices
->np.meshgrid
(following Add NonbondedInteractionGroup potential, rename existing nonbonded potentials #578 (comment) )Reviewer note: This branch "vmap-jax-nonbonded-ig" is a scope-reduced version of "vmap-jax-nonbonded" (omitting complexity of handling intramolecular exceptions / exclusions, etc., for inclusion in a later PR after deboggling week).