From 107a40c286f98ff7ab7369eb602c0d0b7df2b464 Mon Sep 17 00:00:00 2001 From: saratk1 Date: Tue, 27 Feb 2024 18:34:25 +0100 Subject: [PATCH 1/5] adjust FEP protocol: add option to use all provided equilibrium samples --- endstate_correction/neq.py | 36 ++++++++++++++++++++++-------------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/endstate_correction/neq.py b/endstate_correction/neq.py index cb086c4..e6fd111 100644 --- a/endstate_correction/neq.py +++ b/endstate_correction/neq.py @@ -23,7 +23,7 @@ def perform_switching( save_endstates: bool = False, workdir: str = ".", ) -> Tuple[list, list]: - """Perform NEQ switching using the provided lambda schema on the passed simulation instance. + """Perform NEQ or instantaneous switching using the provided lambda schema on the passed simulation instance. Args: sim (Simulation): simulation instance @@ -37,7 +37,7 @@ def perform_switching( RuntimeError: if the number of lambda states is less than 2 Returns: - Tuple[list, list]: work values, endstate samples + Tuple[list, list]: work or dE values, endstate samples """ os.makedirs(workdir, exist_ok=True) if save_endstates: @@ -49,13 +49,14 @@ def perform_switching( ws = [] # list for all endstate samples (can be empty if saving is not needed) endstate_samples = [] - # list for all switching trajectories (can be empty if saving is not needed) - all_switching_trajectories = [] inst_switching = False if len(lambdas) == 2: print("Instantanious switching: dE will be calculated") inst_switching = True + if nr_of_switches == -1: # if no specific nr_of_switches is provided (-1 is the default value), use all provided equilibrium samples + nr_of_switches = len(samples) + print(f"{nr_of_switches} dE values will be calculated") elif len(lambdas) < 2: raise RuntimeError("increase the number of lambda states") else: @@ -63,14 +64,21 @@ def perform_switching( # start with switch for switch_index in tqdm(range(nr_of_switches)): - # select a random frame - random_frame_idx = 
random.randint(0, len(samples.xyz) - 1) - # select the coordinates of the random frame - coord = samples.openmm_positions(random_frame_idx) - if samples.unitcell_lengths is not None: - box_length = samples.openmm_boxes(random_frame_idx) - else: - box_length = None + if inst_switching and nr_of_switches == len(samples): # if all samples should be used for instantaneous switching + coord = samples.openmm_positions(switch_index) + if samples.unitcell_lengths is not None: + box_length = samples.openmm_boxes(switch_index) + else: + box_length = None + else: # if a specific number of instantaneous switches should be calculated, random conformations will be drawn from the provided equilibrium samples + # select a random frame + random_frame_idx = random.randint(0, len(samples.xyz) - 1) + # select the coordinates of the random frame + coord = samples.openmm_positions(random_frame_idx) + if samples.unitcell_lengths is not None: + box_length = samples.openmm_boxes(random_frame_idx) + else: + box_length = None # set position sim.context.setPositions(coord) if box_length is not None: @@ -137,12 +145,12 @@ def perform_switching( ) if save_endstates: # save the endstate conformation - endstate_samples.append(get_positions(sim)) + endstate_samples.append(get_positions(sim).value_in_unit(unit.nanometer)) # get all work values ws.append(w) return ( np.array(ws) * unit.kilojoule_per_mole, - endstate_samples, + endstate_samples ) From 55d9146dd569aaa6a264c9efc29e790155021303 Mon Sep 17 00:00:00 2001 From: saratk1 Date: Tue, 27 Feb 2024 18:35:54 +0100 Subject: [PATCH 2/5] delete outdated parts (switching trajectory is now directly saved to a dcd file) --- endstate_correction/protocol.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/endstate_correction/protocol.py b/endstate_correction/protocol.py index 67e4526..d218d94 100644 --- a/endstate_correction/protocol.py +++ b/endstate_correction/protocol.py @@ -224,12 +224,6 @@ class NEQResults(BaseResults): 
endstate_samples_target_to_reference: np.array = field( default_factory=lambda: np.array([]) ) # endstate samples from target to reference - switching_traj_reference_to_target: np.array = field( - default_factory=lambda: np.array([]) - ) # switching traj from reference to target - switching_traj_target_to_reference: np.array = field( - default_factory=lambda: np.array([]) - ) # switching traj from target to reference @dataclass From 1f2e2b5dd79941c01749efb03edaf4b0f8cd5d23 Mon Sep 17 00:00:00 2001 From: saratk1 Date: Tue, 27 Feb 2024 18:36:14 +0100 Subject: [PATCH 3/5] add test for fep protocol --- .../tests/test_endstate_correction.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/endstate_correction/tests/test_endstate_correction.py b/endstate_correction/tests/test_endstate_correction.py index d6c7682..9597e7b 100644 --- a/endstate_correction/tests/test_endstate_correction.py +++ b/endstate_correction/tests/test_endstate_correction.py @@ -335,8 +335,6 @@ def test_each_protocol(): assert len(r.neq_results.W_target_to_reference) == 0 assert len(r.neq_results.endstate_samples_reference_to_target) == 0 assert len(r.neq_results.endstate_samples_reference_to_target) == 0 - assert len(r.neq_results.switching_traj_reference_to_target) == 0 - assert len(r.neq_results.switching_traj_target_to_reference) == 0 assert r.equ_results == None neq_protocol = NEQProtocol( @@ -353,11 +351,18 @@ def test_each_protocol(): assert len(r.neq_results.W_target_to_reference) == neq_protocol.nr_of_switches assert len(r.neq_results.endstate_samples_reference_to_target) == 0 assert len(r.neq_results.endstate_samples_reference_to_target) == 0 - assert len(r.neq_results.switching_traj_reference_to_target) == 0 - assert len(r.neq_results.switching_traj_target_to_reference) == 0 assert r.equ_results == None - # test saving endstates and saving trajectory option + # if no specific number of switches is given, the protocol should take all provided equilibrium 
samples + fep_protocol = FEPProtocol( + sim=sim, + reference_samples=mm_samples[:20], + ) + r = perform_endstate_correction(fep_protocol) + r_fep = r.fep_results + assert len(r_fep.dE_reference_to_target) == 20 + + # test saving endstates option protocol = NEQProtocol( sim=sim, reference_samples=mm_samples, From 66641b3c53aa71a92cadfd2725fca302d38b8ce5 Mon Sep 17 00:00:00 2001 From: saratk1 Date: Tue, 27 Feb 2024 18:36:55 +0100 Subject: [PATCH 4/5] correct loading of equilibrium samples --- scripts/perform_correction_hipen.py | 71 ++++++++++++++++------------- 1 file changed, 39 insertions(+), 32 deletions(-) diff --git a/scripts/perform_correction_hipen.py b/scripts/perform_correction_hipen.py index 6d57fa4..86c1826 100644 --- a/scripts/perform_correction_hipen.py +++ b/scripts/perform_correction_hipen.py @@ -11,7 +11,7 @@ from endstate_correction.constant import zinc_systems, blacklist from endstate_correction.analysis import plot_endstate_correction_results import endstate_correction -from endstate_correction.protocol import perform_endstate_correction, Protocol +from endstate_correction.protocol import perform_endstate_correction, FEPProtocol, NEQProtocol import mdtraj from openmm import unit import pickle, sys, os @@ -30,8 +30,7 @@ env = "vacuum" -print(system_name) -print(env) +print(f"Setting up system {system_name} in {env}") # define directory containing parameters parameter_base = f"{package_path}/data/hipen_data" @@ -53,12 +52,15 @@ chains = list(psf.topology.chains()) ml_atoms = [atom.index for atom in chains[0].atoms()] print(f"{ml_atoms=}") +print("Creating mm system...") mm_system = psf.createSystem(params=params) # define system potential = MLPotential("ani2x") +print("Creating mixed system...") ml_system = potential.createMixedSystem( psf.topology, mm_system, ml_atoms, interpolate=True ) +print("Creating simulation object...") sim = Simulation(psf.topology, ml_system, LangevinIntegrator(300, 1, 0.001)) 
######################################################## ######################################################## @@ -70,71 +72,76 @@ traj_base = f"/data/shared/projects/endstate_rew/{system_name}/sampling_charmmff/" # load MM samples -mm_samples = [] +mm_samples_list = [] for i in range(1, 4): base = f"{traj_base}/run0{i}/{system_name}_samples_{n_samples}_steps_{n_steps_per_sample}_lamb_0.0000" # if needed, convert pickle file to dcd # convert_pickle_to_dcd_file(f"{base}.pickle",psf_file, crd_file, f"{base}.dcd", "temp.pdb") traj = mdtraj.load_dcd( f"{base}.dcd", - top=psf_file, - ) - mm_samples.extend(traj[1000:].xyz * unit.nanometer) # NOTE: this is in nanometer! + top=psf_file, # also possible to use the tmp.pdb + )[int((n_samples / 100) * 20):] + mm_samples_list.append(traj) + +mm_samples = mdtraj.join(mm_samples_list) #* unit.nanometer +assert isinstance(mm_samples, mdtraj.Trajectory) print(f"Initializing switch from {len(mm_samples)} MM samples") # load NNP samples -nnp_samples = [] +nnp_samples_list = [] for i in range(1, 4): base = f"{traj_base}/run0{i}/{system_name}_samples_{n_samples}_steps_{n_steps_per_sample}_lamb_1.0000" # if needed, convert pickle file to dcd # convert_pickle_to_dcd_file(f"{base}.pickle",psf_file, crd_file, f"{base}.dcd", "temp.pdb") traj = mdtraj.load_dcd( f"{base}.dcd", - top=psf_file, - ) - nnp_samples.extend(traj[1000:].xyz * unit.nanometer) # NOTE: this is in nanometer! 
-print(f"Initializing switch from {len(mm_samples)} NNP samples") + top=psf_file, # also possible to use the tmp.pdb + )[int((n_samples / 100) * 20):] + nnp_samples_list.append(traj) + +nnp_samples = mdtraj.join(nnp_samples_list) #* unit.nanometer +assert isinstance(nnp_samples, mdtraj.Trajectory) +print(f"Initializing switch from {len(nnp_samples)} NNP samples") ######################################################## ######################################################## # ----------------- perform correction ----------------# # define the output directory -output_base = f"/data/shared/projects/endstate_rew/{system_name}/switching_new/" +output_base = f"/data/shared/projects/endstate_rew/{system_name}/FEP_v1/" os.makedirs(output_base, exist_ok=True) #################################################### # ---------------- FEP protocol -------------------- #################################################### -fep_protocol = Protocol( - method="FEP", - direction="bidirectional", +fep_protocol = FEPProtocol( sim=sim, - trajectories=[mm_samples, nnp_samples], - nr_of_switches=10, # 2_000, + reference_samples=mm_samples, + target_samples=nnp_samples, + nr_of_switches=2_000, # if not provided, the protocol will use all provided equilibrium samples ) #################################################### # ----------------- NEQ protocol ------------------- #################################################### -neq_protocol = Protocol( - method="NEQ", - direction="bidirectional", - sim=sim, - trajectories=[mm_samples, nnp_samples], - nr_of_switches=3, # 500, - neq_switching_length=5, # _000, - save_endstates=True, - save_trajs=True, -) +# neq_protocol = NEQProtocol( + # sim=sim, + # reference_samples=mm_samples, + # target_samples=nnp_samples, + # nr_of_switches=3, # 500, + # switching_length=5, # _000, + # save_endstates=False, + # save_trajs=False, +# ) # perform correction r_fep = perform_endstate_correction(fep_protocol) -r_neq = 
perform_endstate_correction(neq_protocol) +#r_neq = perform_endstate_correction(neq_protocol) # save fep and neq results in a pickle file -pickle.dump((r_fep, r_neq), open(f"{output_base}/results.pickle", "wb")) +#pickle.dump((r_fep, r_neq), open(f"{output_base}/results.pickle", "wb")) +pickle.dump((r_fep), open(f"{output_base}/fep_results_{system_name}.pickle", "wb")) # plot results -plot_endstate_correction_results(system_name, r_fep, f"{output_base}/results_neq.png") -plot_endstate_correction_results(system_name, r_neq, f"{output_base}/results_neq.png") +plot_endstate_correction_results(system_name, r_fep, f"{output_base}/results_fep_{system_name}.png") +#plot_endstate_correction_results(system_name, r_neq, f"{output_base}/results_neq.png") From 6549a4af4d33a42d83acb6ead9cbad9e9cb400bd Mon Sep 17 00:00:00 2001 From: saratk1 Date: Wed, 28 Feb 2024 11:13:09 +0100 Subject: [PATCH 5/5] make it clearer how many samples are used for FEP --- endstate_correction/neq.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/endstate_correction/neq.py b/endstate_correction/neq.py index e6fd111..ec8cae9 100644 --- a/endstate_correction/neq.py +++ b/endstate_correction/neq.py @@ -56,7 +56,9 @@ def perform_switching( inst_switching = True if nr_of_switches == -1: # if no specific nr_of_switches is provided (-1 is the default value), use all provided equilibrium samples nr_of_switches = len(samples) - print(f"{nr_of_switches} dE values will be calculated") + print(f"{nr_of_switches} dE values will be calculated using all provided equilibrium samples") + else: + print(f"{nr_of_switches} dE values will be calculated using {nr_of_switches} random equilibrium samples") elif len(lambdas) < 2: raise RuntimeError("increase the number of lambda states") else: