Skip to content

Commit

Permalink
Merge pull request #95 from wiederm/update_fep
Browse files Browse the repository at this point in the history
Update FEP Protocol
  • Loading branch information
wiederm authored Feb 28, 2024
2 parents 162e386 + 6549a4a commit 63b92ab
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 57 deletions.
38 changes: 24 additions & 14 deletions endstate_correction/neq.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def perform_switching(
save_endstates: bool = False,
workdir: str = ".",
) -> Tuple[list, list]:
"""Perform NEQ switching using the provided lambda schema on the passed simulation instance.
"""Perform NEQ or instantaneous switching using the provided lambda schema on the passed simulation instance.
Args:
sim (Simulation): simulation instance
Expand All @@ -37,7 +37,7 @@ def perform_switching(
RuntimeError: if the number of lambda states is less than 2
Returns:
Tuple[list, list]: work values, endstate samples
Tuple[list, list]: work or dE values, endstate samples
"""
os.makedirs(workdir, exist_ok=True)
if save_endstates:
Expand All @@ -49,28 +49,38 @@ def perform_switching(
ws = []
# list for all endstate samples (can be empty if saving is not needed)
endstate_samples = []
# list for all switching trajectories (can be empty if saving is not needed)
all_switching_trajectories = []

inst_switching = False
if len(lambdas) == 2:
print("Instantanious switching: dE will be calculated")
inst_switching = True
if nr_of_switches == -1: # if no specific nr_of_switches is provided (-1 is the default value), use all provided equilibrium samples
nr_of_switches = len(samples)
print(f"{nr_of_switches} dE values will be calculated using all provided equilibrium samples")
else:
print(f"{nr_of_switches} dE values will be calculated using {nr_of_switches} random equilibrium samples")
elif len(lambdas) < 2:
raise RuntimeError("increase the number of lambda states")
else:
print("NEQ switching: dW will be calculated")

# start with switch
for switch_index in tqdm(range(nr_of_switches)):
# select a random frame
random_frame_idx = random.randint(0, len(samples.xyz) - 1)
# select the coordinates of the random frame
coord = samples.openmm_positions(random_frame_idx)
if samples.unitcell_lengths is not None:
box_length = samples.openmm_boxes(random_frame_idx)
else:
box_length = None
if inst_switching and nr_of_switches == len(samples): # if all samples should be used for instantaneous switching
coord = samples.openmm_positions(switch_index)
if samples.unitcell_lengths is not None:
box_length = samples.openmm_boxes(switch_index)
else:
box_length = None
else: # if a specific number of instantaneous switches should be calculated, random conformations will be drawn from the provided equilibrium samples
# select a random frame
random_frame_idx = random.randint(0, len(samples.xyz) - 1)
# select the coordinates of the random frame
coord = samples.openmm_positions(random_frame_idx)
if samples.unitcell_lengths is not None:
box_length = samples.openmm_boxes(random_frame_idx)
else:
box_length = None
# set position
sim.context.setPositions(coord)
if box_length is not None:
Expand Down Expand Up @@ -137,12 +147,12 @@ def perform_switching(
)
if save_endstates:
# save the endstate conformation
endstate_samples.append(get_positions(sim))
endstate_samples.append(get_positions(sim).value_in_unit(unit.nanometer))
# get all work values
ws.append(w)
return (
np.array(ws) * unit.kilojoule_per_mole,
endstate_samples,
endstate_samples
)


Expand Down
6 changes: 0 additions & 6 deletions endstate_correction/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,12 +224,6 @@ class NEQResults(BaseResults):
endstate_samples_target_to_reference: np.array = field(
default_factory=lambda: np.array([])
) # endstate samples from target to reference
switching_traj_reference_to_target: np.array = field(
default_factory=lambda: np.array([])
) # switching traj from reference to target
switching_traj_target_to_reference: np.array = field(
default_factory=lambda: np.array([])
) # switching traj from target to reference


@dataclass
Expand Down
15 changes: 10 additions & 5 deletions endstate_correction/tests/test_endstate_correction.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,8 +335,6 @@ def test_each_protocol():
assert len(r.neq_results.W_target_to_reference) == 0
assert len(r.neq_results.endstate_samples_reference_to_target) == 0
assert len(r.neq_results.endstate_samples_reference_to_target) == 0
assert len(r.neq_results.switching_traj_reference_to_target) == 0
assert len(r.neq_results.switching_traj_target_to_reference) == 0
assert r.equ_results == None

neq_protocol = NEQProtocol(
Expand All @@ -353,11 +351,18 @@ def test_each_protocol():
assert len(r.neq_results.W_target_to_reference) == neq_protocol.nr_of_switches
assert len(r.neq_results.endstate_samples_reference_to_target) == 0
assert len(r.neq_results.endstate_samples_reference_to_target) == 0
assert len(r.neq_results.switching_traj_reference_to_target) == 0
assert len(r.neq_results.switching_traj_target_to_reference) == 0
assert r.equ_results == None

# test saving endstates and saving trajectory option
# if no specific number of switches is given, the protocol should take all provided equilibrium samples
fep_protocol = FEPProtocol(
sim=sim,
reference_samples=mm_samples[:20],
)
r = perform_endstate_correction(fep_protocol)
r_fep = r.fep_results
assert len(r_fep.dE_reference_to_target) == 20

# test saving endstates option
protocol = NEQProtocol(
sim=sim,
reference_samples=mm_samples,
Expand Down
71 changes: 39 additions & 32 deletions scripts/perform_correction_hipen.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from endstate_correction.constant import zinc_systems, blacklist
from endstate_correction.analysis import plot_endstate_correction_results
import endstate_correction
from endstate_correction.protocol import perform_endstate_correction, Protocol
from endstate_correction.protocol import perform_endstate_correction, FEPProtocol, NEQProtocol
import mdtraj
from openmm import unit
import pickle, sys, os
Expand All @@ -30,8 +30,7 @@

env = "vacuum"

print(system_name)
print(env)
print(f"Setting up system {system_name} in {env}")

# define directory containing parameters
parameter_base = f"{package_path}/data/hipen_data"
Expand All @@ -53,12 +52,15 @@
chains = list(psf.topology.chains())
ml_atoms = [atom.index for atom in chains[0].atoms()]
print(f"{ml_atoms=}")
print("Creating mm system...")
mm_system = psf.createSystem(params=params)
# define system
potential = MLPotential("ani2x")
print("Creating mixed system...")
ml_system = potential.createMixedSystem(
psf.topology, mm_system, ml_atoms, interpolate=True
)
print("Creating simulation object...")
sim = Simulation(psf.topology, ml_system, LangevinIntegrator(300, 1, 0.001))
########################################################
########################################################
Expand All @@ -70,71 +72,76 @@
traj_base = f"/data/shared/projects/endstate_rew/{system_name}/sampling_charmmff/"

# load MM samples
mm_samples = []
mm_samples_list = []
for i in range(1, 4):
base = f"{traj_base}/run0{i}/{system_name}_samples_{n_samples}_steps_{n_steps_per_sample}_lamb_0.0000"
# if needed, convert pickle file to dcd
# convert_pickle_to_dcd_file(f"{base}.pickle",psf_file, crd_file, f"{base}.dcd", "temp.pdb")
traj = mdtraj.load_dcd(
f"{base}.dcd",
top=psf_file,
)
mm_samples.extend(traj[1000:].xyz * unit.nanometer) # NOTE: this is in nanometer!
top=psf_file, # also possible to use the tmp.pdb
)[int((n_samples / 100) * 20):]
mm_samples_list.append(traj)

mm_samples = mdtraj.join(mm_samples_list) #* unit.nanometer
assert isinstance(mm_samples, mdtraj.Trajectory)
print(f"Initializing switch from {len(mm_samples)} MM samples")

# load NNP samples
nnp_samples = []
nnp_samples_list = []
for i in range(1, 4):
base = f"{traj_base}/run0{i}/{system_name}_samples_{n_samples}_steps_{n_steps_per_sample}_lamb_1.0000"
# if needed, convert pickle file to dcd
# convert_pickle_to_dcd_file(f"{base}.pickle",psf_file, crd_file, f"{base}.dcd", "temp.pdb")
traj = mdtraj.load_dcd(
f"{base}.dcd",
top=psf_file,
)
nnp_samples.extend(traj[1000:].xyz * unit.nanometer) # NOTE: this is in nanometer!
print(f"Initializing switch from {len(mm_samples)} NNP samples")
top=psf_file, # also possible to use the tmp.pdb
)[int((n_samples / 100) * 20):]
nnp_samples_list.append(traj)

nnp_samples = mdtraj.join(nnp_samples_list) #* unit.nanometer
assert isinstance(nnp_samples, mdtraj.Trajectory)
print(f"Initializing switch from {len(nnp_samples)} NNP samples")

########################################################
########################################################
# ----------------- perform correction ----------------#

# define the output directory
output_base = f"/data/shared/projects/endstate_rew/{system_name}/switching_new/"
output_base = f"/data/shared/projects/endstate_rew/{system_name}/FEP_v1/"
os.makedirs(output_base, exist_ok=True)

####################################################
# ---------------- FEP protocol --------------------
####################################################
fep_protocol = Protocol(
method="FEP",
direction="bidirectional",
fep_protocol = FEPProtocol(
sim=sim,
trajectories=[mm_samples, nnp_samples],
nr_of_switches=10, # 2_000,
reference_samples=mm_samples,
target_samples=nnp_samples,
nr_of_switches=2_000, # if not provided, the protocol will use all provided equilibrium samples
)

####################################################
# ----------------- NEQ protocol -------------------
####################################################
neq_protocol = Protocol(
method="NEQ",
direction="bidirectional",
sim=sim,
trajectories=[mm_samples, nnp_samples],
nr_of_switches=3, # 500,
neq_switching_length=5, # _000,
save_endstates=True,
save_trajs=True,
)
# neq_protocol = NEQProtocol(
# sim=sim,
# reference_samples=mm_samples,
# target_samples=nnp_samples,
# nr_of_switches=3, # 500,
# switching_length=5, # _000,
# save_endstates=False,
# save_trajs=False,
# )

# perform correction
r_fep = perform_endstate_correction(fep_protocol)
r_neq = perform_endstate_correction(neq_protocol)
#r_neq = perform_endstate_correction(neq_protocol)

# save fep and neq results in a pickle file
pickle.dump((r_fep, r_neq), open(f"{output_base}/results.pickle", "wb"))
#pickle.dump((r_fep, r_neq), open(f"{output_base}/results.pickle", "wb"))
pickle.dump((r_fep), open(f"{output_base}/fep_results_{system_name}.pickle", "wb"))

# plot results
plot_endstate_correction_results(system_name, r_fep, f"{output_base}/results_neq.png")
plot_endstate_correction_results(system_name, r_neq, f"{output_base}/results_neq.png")
plot_endstate_correction_results(system_name, r_fep, f"{output_base}/results_fep_{system_name}.png")
#plot_endstate_correction_results(system_name, r_neq, f"{output_base}/results_neq.png")

0 comments on commit 63b92ab

Please sign in to comment.