Fast batched nonbonded on interaction groups, as fxn of ligand charges and epsilons #685

Merged
merged 64 commits into from
Apr 4, 2022

@maxentile maxentile commented Apr 1, 2022

Improved Jax utilities for processing batches of snapshots, with the aim of computing a batch of nonbonded energies as an efficient differentiable function of ligand nonbonded parameters.


Main feature:

Utilities for precomputing small summary vectors of a stored trajectory, so that ligand-environment nonbonded interaction group energies as a function of ligand charges (or ligand epsilons) can be computed as dot products with the small summary vectors. This is directly inspired by the linear basis function approach developed by Levi Naden and Michael Shirts in https://pubmed.ncbi.nlm.nih.gov/26580188/ .

(Note: These utilities make reweighting virtually a no-op, cost-wise. Rough timing: ~5 microseconds to compute a reweighted hydration free energy in explicit solvent, using JAX in double-precision CPU mode.)
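The summary-vector idea can be illustrated with a minimal sketch. It assumes a bare Coulomb term with unit prefactor and no cutoff, periodic boundaries, or reaction-field correction; the function names `coulomb_summary_vectors` and `batch_energy_from_summary` are hypothetical and are not the utilities added in this PR. Because the ligand-environment energy is linear in the ligand charges, one [T, N_lig] array computed once per trajectory is enough to evaluate the whole batch of energies for any new ligand charges.

```python
import jax
import jax.numpy as jnp


def coulomb_summary_vectors(traj, charges, ligand_indices, env_indices):
    """Hypothetical helper: s[t, i] = sum_j q_env[j] / r_ij(t) for each ligand atom i."""
    x_lig = traj[:, ligand_indices]  # [T, N_lig, D]
    x_env = traj[:, env_indices]  # [T, N_env, D]
    q_env = charges[env_indices]  # [N_env]
    # pairwise ligand-environment distances, [T, N_lig, N_env]
    r = jnp.linalg.norm(x_lig[:, :, None, :] - x_env[:, None, :, :], axis=-1)
    return jnp.sum(q_env / r, axis=-1)  # [T, N_lig]


def batch_energy_from_summary(summary, q_ligand):
    """Ligand-environment energy per snapshot as a dot product, [T]."""
    return summary @ q_ligand


# toy data (assumption: random coordinates and charges, for illustration only)
key_x, key_q = jax.random.split(jax.random.PRNGKey(0))
traj = jax.random.normal(key_x, (4, 10, 3)) * 3.0
charges = jax.random.normal(key_q, (10,))
ligand_indices = jnp.arange(3)
env_indices = jnp.arange(3, 10)

summary = coulomb_summary_vectors(traj, charges, ligand_indices, env_indices)
q_ligand = charges[ligand_indices]
u_batch = batch_energy_from_summary(summary, q_ligand)  # shape [4]
```

After the one-time precomputation, each energy evaluation (and its gradient with respect to `q_ligand`) costs one small matrix-vector product, independent of the number of environment atoms.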


Additional refactoring:

  • Refactor jax nonbonded functions to use pairs ([N, 2] array) instead of a pair of inds_l, inds_r (2 [N,] arrays)
  • Extract function for converting nonbonded exclusions to rescale masks from tests.common: nonbonded.convert_exclusions_to_rescale_masks (initially named convert_exceptions_to_rescale_masks, renamed during review)
  • Extract function to get pairs_from_interaction_groups
  • Move unused jax function rescale_coordinates to attic
    • (unsure if broken?)
  • jax.ops.index_update(x, idx, y) -> x.at[idx].set(y), x.at[idx].multiply(0.5)
    • This update was already applied in "Replaces index_update with .at().set()" (#672), but while resolving merge conflicts I tweaked these slightly to use named slices (params.at[:,0].multiply(0.5) -> params.at[charge_indices].multiply(0.5))
  • np.indices -> np.meshgrid (following "Add NonbondedInteractionGroup potential, rename existing nonbonded potentials" #578 (comment))
  • Change the return type of reference nonbonded interaction groups (no longer returns pairs)
  • Add functions to document Lennard-Jones combining rules
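
As a rough illustration of the pairs-based representation and the np.meshgrid approach mentioned in the list above, the sketch below builds a single [N, 2] array of index pairs from two interaction groups. The function name and signature are assumptions for illustration; the actual pairs_from_interaction_groups in this PR may differ. Plain NumPy is used here because index bookkeeping does not need to be differentiable.

```python
import numpy as np


def sketch_pairs_from_interaction_groups(group_a_idxs, group_b_idxs):
    """Hypothetical sketch: all (a, b) index pairs as one [len(a) * len(b), 2] array."""
    a, b = np.meshgrid(group_a_idxs, group_b_idxs, indexing="ij")
    return np.stack([a.ravel(), b.ravel()], axis=1)


pairs = sketch_pairs_from_interaction_groups([0, 1], [5, 6, 7])
# a single array replaces the former (inds_l, inds_r) pair of arrays
inds_l, inds_r = pairs[:, 0], pairs[:, 1]
```

Keeping the pairs in one array means a downstream function can take a single argument and recover the left and right indices by column slicing when needed.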

Reviewer note: This branch "vmap-jax-nonbonded-ig" is a scope-reduced version of "vmap-jax-nonbonded" (omitting complexity of handling intramolecular exceptions / exclusions, etc., for inclusion in a later PR after deboggling week).

Utility function to find atom pairs that depend on ligand and contribute to nonbonded sum for each conf in a batch of confs...
Possibly save some debugging time in the future...
* compute_batch_nonbonded_exceptions
* get_intramolecular_nb_pairs_and_prefactors
* construct_batch_nonbonded_hl_as_fxn_of_ligand_nb_params
* construct_batch_nonbonded_ll_as_fxn_of_ligand_nb_params
tests/common.py Outdated
for (i, j), exc in zip(exclusion_idxs, scales[:, 1]):
    lj_rescale_mask[i][j] = 1 - exc
    lj_rescale_mask[j][i] = 1 - exc
charge_rescale_mask, lj_rescale_mask = nonbonded.convert_exceptions_to_rescale_masks(exclusion_idxs, scales, N)
Collaborator

Point of clarification, is an exception an exclusion?

Collaborator Author

Ahh, this should be "exclusions" -- will rename

(I thought "exceptions" are pairs that need to be handled differently, "exclusions" are pairs that are omitted entirely (special case of exceptions), but this is not actually the convention.)

Collaborator Author

Renamed in c4077ae
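
For readers following this thread, here is a minimal sketch of what converting exclusions to rescale masks could look like, extrapolated from the loop in the diff above. It is not the code in nonbonded.convert_exclusions_to_rescale_masks. Two assumptions are made: the masks start as all-ones [N, N] arrays, and column 0 of scales holds the charge scale factor (the diff only shows column 1 being used for Lennard-Jones).

```python
import numpy as np


def sketch_exclusions_to_rescale_masks(exclusion_idxs, scales, N):
    """Hypothetical sketch: symmetric [N, N] masks with entries 1 - scale for excluded pairs."""
    charge_rescale_mask = np.ones((N, N))
    lj_rescale_mask = np.ones((N, N))
    for (i, j), (charge_scale, lj_scale) in zip(exclusion_idxs, scales):
        # assumption: scales[:, 0] applies to charges, scales[:, 1] to Lennard-Jones
        charge_rescale_mask[i, j] = charge_rescale_mask[j, i] = 1 - charge_scale
        lj_rescale_mask[i, j] = lj_rescale_mask[j, i] = 1 - lj_scale
    return charge_rescale_mask, lj_rescale_mask


exclusion_idxs = np.array([[0, 1], [1, 2]])
scales = np.array([[1.0, 1.0], [0.5, 0.25]])
charge_mask, lj_mask = sketch_exclusions_to_rescale_masks(exclusion_idxs, scales, N=3)
```

Under this convention a scale of 1.0 removes the pair entirely (mask entry 0), while a fractional scale only attenuates it.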

@@ -546,8 +547,12 @@ def parameterize_nonbonded(self, ff_q_params, ff_lj_params):
qlj_params, nb_potential = super().parameterize_nonbonded(ff_q_params, ff_lj_params)

# halve the strength of the charge and the epsilon parameters
src_qlj_params = jnp.asarray(qlj_params).at[:, 0].multiply(0.5)
src_qlj_params = jnp.asarray(src_qlj_params).at[:, 2].multiply(0.5)
charge_indices = jax.numpy.index_exp[:, 0]
Collaborator

Is there a reason jax.numpy was used here instead of jnp?

Collaborator Author

Nope -- good catch!

Collaborator

@badisa badisa Apr 4, 2022

~~Also is there a reason to use the index_exp syntax over the .at syntax?~~

edit: Is there a reason to use index_exp with .at rather than do the indexing directly? Just to reuse indices and make it more explicit?

Collaborator Author

Just to be more explicit -- should have identical behavior

Collaborator Author

Partially addressed in 4becd4a
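
The equivalence discussed in this thread can be checked with a small sketch. The toy params array and the name epsilon_indices are assumptions for illustration; charge_indices and the column meanings (column 0 for charge, column 2 for epsilon) follow the diff above.

```python
import jax.numpy as jnp

# toy [N, 3] parameter array; columns assumed to be (charge, sigma, epsilon)
params = jnp.arange(12.0).reshape(4, 3)

# named index expressions: jnp.index_exp[:, 0] is the tuple (slice(None), 0)
charge_indices = jnp.index_exp[:, 0]
epsilon_indices = jnp.index_exp[:, 2]

# halve charges and epsilons using the named slices
named = params.at[charge_indices].multiply(0.5).at[epsilon_indices].multiply(0.5)

# the same update with the indexing written directly
direct = params.at[:, 0].multiply(0.5).at[:, 2].multiply(0.5)
```

Both forms return a new array and leave params unchanged, so the choice is about readability and reuse of the index expressions rather than behavior.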

@@ -1,34 +1,30 @@
from typing import Tuple

import jax
import jax.numpy as np
import numpy as onp
from jax import vmap

Array = onp.array
Collaborator

nit: Since you are changing the usage, might be good to switch this to onp.typing.NDArray?

Collaborator Author

Updated in 78ac174
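
A brief sketch of the suggestion in this thread, importing NDArray from numpy.typing explicitly; the helper center_coords is a made-up example and is not part of the PR.

```python
import numpy as onp
from numpy.typing import NDArray

# onp.array is a factory function, so it is not meaningful as a type annotation;
# NDArray is a generic alias of onp.ndarray intended for annotations.
Array = NDArray


def center_coords(x: Array) -> Array:
    """Subtract the mean over atoms from an [N, D] coordinate array."""
    return x - x.mean(axis=0, keepdims=True)


centered = center_coords(onp.ones((2, 3)))
```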

traj: [T, N, D] array
boxes: diagonal [T, D, D] array (or [T] array of Nones)
charges: [N] array
ligand_indices: [N_lig] array of ints
Collaborator

Are there any limitations we want to place on ligand_indices and env_indices? Should they be disjoint?

Collaborator Author

They should be disjoint (throw error if not), and ligand_indices should be smaller than env_indices (print warning if not) -- will fix

Collaborator Author

Partially addressed in 762fc61 (and applied to nonbonded_v3_interaction_groups as well)
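
One possible shape for the validation described in this thread is sketched below. The function name is hypothetical, and "ligand_indices should be smaller than env_indices" is read here as "fewer ligand atoms than environment atoms", which is an assumption about the intended meaning rather than the check actually committed.

```python
import warnings

import numpy as np


def sketch_validate_interaction_group_indices(ligand_indices, env_indices):
    """Hypothetical sketch: error on overlapping groups, warn on an unexpectedly large ligand."""
    ligand = set(int(i) for i in np.asarray(ligand_indices).ravel())
    env = set(int(i) for i in np.asarray(env_indices).ravel())

    overlap = ligand & env
    if overlap:
        raise ValueError(f"ligand_indices and env_indices must be disjoint, but share {sorted(overlap)}")

    # assumption: "smaller" means fewer indices, not smaller index values
    if len(ligand) >= len(env):
        warnings.warn("ligand_indices is not smaller than env_indices; were the groups swapped?")


sketch_validate_interaction_group_indices([0, 1], [2, 3, 4])  # passes silently
```

Raising on overlap guards against double-counting a pair, while the size check is only a warning because a large "ligand" group is unusual but not necessarily wrong.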

* Partially address #685 (comment)
* Also validate input for nonbonded_v3_interaction_groups
reweight_ref = jit(make_reweighter(u_batch_ref))
reweight_test = jit(make_reweighter(u_batch_test))

for _ in range(50):
Collaborator

This test takes about 50 seconds on CI, does this need 50 iterations to be correct?

Collaborator Author

Ahh, on Friday it was ~6 seconds -- will track down which change made it more expensive and reduce iterations if needed

Collaborator Author

@maxentile maxentile Apr 4, 2022

Locally, ~80-90% of the iteration time is spent computing reference gradients:

v_ref, g_ref = value_and_grad(reweight_ref, argnums=argnum)(eps_ligand, q_ligand)

(added in response to #685 (comment)).

Will rearrange this slightly and reduce the number of iterations from 50 to 5...

Collaborator Author

Addressed in 930ae0d

(in initial version of test the iterations were ~ 0.03 seconds each, but in current version the iterations are ~0.3 seconds each, so it's good to reduce the number of iterations!)
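
To make the value_and_grad pattern in this thread concrete, here is a toy sketch of a differentiable reweighting estimate. It is not the make_reweighter used in the test. It assumes one-sided exponential averaging over reduced energies, takes an extra u_ref argument that the real function may not have, and uses a made-up batch energy function that is linear in the ligand parameters.

```python
import jax
import jax.numpy as jnp
from jax import jit, value_and_grad
from jax.scipy.special import logsumexp


def make_reweighter_sketch(u_batch_fxn, u_ref):
    """Hypothetical: u_batch_fxn(eps, q) -> [T] reduced energies; u_ref: [T] energies of the sampled state."""

    def reweight(eps_ligand, q_ligand):
        delta_u = u_batch_fxn(eps_ligand, q_ligand) - u_ref
        # exponential averaging: delta_f = -log(mean(exp(-delta_u)))
        return -(logsumexp(-delta_u) - jnp.log(delta_u.shape[0]))

    return reweight


# toy batch energy function, linear in both parameter sets (assumption)
key_q, key_eps = jax.random.split(jax.random.PRNGKey(0))
s_q = jax.random.normal(key_q, (6, 3))
s_eps = jax.random.normal(key_eps, (6, 3))


def u_batch(eps_ligand, q_ligand):
    return s_eps @ eps_ligand + s_q @ q_ligand


eps0 = jnp.full(3, 0.5)
q0 = jnp.array([0.1, -0.2, 0.1])

reweight = jit(make_reweighter_sketch(u_batch, u_batch(eps0, q0)))
# argnums=(0, 1) requests gradients with respect to both epsilons and charges
v, (g_eps, g_q) = value_and_grad(reweight, argnums=(0, 1))(eps0, q0)
```

At the reference parameters the estimate is zero and each gradient reduces to the sample mean of the corresponding energy derivative, which gives a cheap sanity check for a test that compares a fast implementation against a reference one.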

@maxentile
Collaborator Author

Thanks for the reviews! Enabling auto-merge

@maxentile maxentile enabled auto-merge (squash) April 4, 2022 15:28
@maxentile maxentile merged commit 0c4b610 into master Apr 4, 2022
@maxentile maxentile deleted the vmap-jax-nonbonded-ig branch June 15, 2022 13:04
@maxentile maxentile mentioned this pull request Nov 18, 2022
@maxentile maxentile mentioned this pull request Dec 4, 2023
6 tasks
maxentile added a commit that referenced this pull request Dec 5, 2023
restores functionality from #685 overwritten in #931
@maxentile maxentile mentioned this pull request Dec 5, 2023
maxentile added a commit that referenced this pull request Dec 5, 2023
* restore lj_eps_prefactor special case

restores functionality from #685 overwritten in #931

* add lj eps branch in test
5 participants