Cooler not compatible with multiprocessing #441

Open
pghzeng opened this issue Nov 30, 2024 · 0 comments

pghzeng commented Nov 30, 2024

I'm using the latest version of cooler (0.9.2). For clarity, num_processes is set to 1 here; the problem is the same.

# a basic example
import cooler
import torch
from torch.multiprocessing import Pool
# from multiprocessing import Pool
# from pathos.multiprocessing import ProcessingPool as Pool
import gc

def load_targets(args):
    # Worker: open one .mcool file and fetch the matrix for each region.
    sample_name, cool_dir, sampled_regions = args
    cool_path = f"{cool_dir}/{sample_name}.sumnorm.mcool::/"
    targets = []
    cooler_file = cooler.Cooler(cool_path)
    print(cooler_file)
    for region in sampled_regions:
        # pdb.set_trace()
        matrix = cooler_file.matrix(balance=False).fetch(region)
        targets.append(torch.tensor(matrix))
    targets = torch.stack(targets)
    print("fetched targets", targets.shape)
    # del cooler_file
    # gc.collect()
    return targets

cool_dir = '/mnt/d/cools/'
test_samples = ["K562_MboI", "endoC", "AoTCPCs", "Liver"]
test_regions = [('chr12', 93236000, 95236000), ('chr2', 219504000, 221504000)]

num_processes = 1
for i in range(3):
    print("epoch%s!!!" % i)
    # One pool per "epoch"; each worker handles all regions for one sample.
    with Pool(num_processes) as pool:
        targets = pool.map(load_targets, [(sample_name, cool_dir, test_regions) for sample_name in test_samples])
    targets = torch.cat(targets)
    print("Shape!!!", targets.shape)
    # del targets
    # gc.collect()

And the output:

epoch0!!!
<Cooler "K562_MboI.sumnorm.mcool::/">
fetched targets torch.Size([2, 500, 500])
<Cooler "endoC.sumnorm.mcool::/">
fetched targets torch.Size([2, 500, 500])
<Cooler "AoTCPCs.sumnorm.mcool::/">
fetched targets torch.Size([2, 500, 500])
<Cooler "Liver.sumnorm.mcool::/">
fetched targets torch.Size([2, 500, 500])
Shape!!! torch.Size([8, 500, 500])
epoch1!!!
<Cooler "K562_MboI.sumnorm.mcool::/">
<Cooler "endoC.sumnorm.mcool::/">
fetched targets torch.Size([2, 500, 500])
<Cooler "AoTCPCs.sumnorm.mcool::/">
fetched targets torch.Size([2, 500, 500])
<Cooler "Liver.sumnorm.mcool::/">
fetched targets torch.Size([2, 500, 500])

You'll find the child process gets stuck: the loop does not move forward and no error is raised. The child process just waits; if you stop the program manually, you'll see:

KeyboardInterrupt                         Traceback (most recent call last)
Cell In[1], line 39
     36 print("epoch%s!!!" % i)
     37 with Pool(num_processes) as pool:
     38     # targets = pool.starmap(load_targets, [(sample_name, cool_dir, test_regions) for sample_name in test_samples])
---> 39     targets = pool.map(load_targets, [(sample_name, cool_dir, test_regions) for sample_name in test_samples])
     40     # print(targets)
     41 targets = torch.cat(targets)

File ~/miniconda3/envs/env/lib/python3.9/multiprocessing/pool.py:364, in Pool.map(self, func, iterable, chunksize)
    359 def map(self, func, iterable, chunksize=None):
    360     '''
    361     Apply `func` to each element in `iterable`, collecting the results
    362     in a list that is returned.
    363     '''
--> 364     return self._map_async(func, iterable, mapstar, chunksize).get()

File ~/miniconda3/envs/env/lib/python3.9/multiprocessing/pool.py:765, in ApplyResult.get(self, timeout)
    764 def get(self, timeout=None):
--> 765     self.wait(timeout)
    766     if not self.ready():
    767         raise TimeoutError

File ~/miniconda3/envs/env/lib/python3.9/multiprocessing/pool.py:762, in ApplyResult.wait(self, timeout)
    761 def wait(self, timeout=None):
--> 762     self._event.wait(timeout)

File ~/miniconda3/envs/env/lib/python3.9/threading.py:581, in Event.wait(self, timeout)
    579 signaled = self._flag
    580 if not signaled:
--> 581     signaled = self._cond.wait(timeout)
    582 return signaled

File ~/miniconda3/envs/env/lib/python3.9/threading.py:312, in Condition.wait(self, timeout)
    310 try:    # restore state no matter what (e.g., KeyboardInterrupt)
    311     if timeout is None:
--> 312         waiter.acquire()
    313         gotit = True
    314     else:

KeyboardInterrupt: 

Interestingly, this only happens if you manipulate the results returned from the pool. If you don't use these lines:

targets = torch.cat(targets)
# or 
targets = torch.stack(targets)
# or some other manipulation I haven't tested

Then the loop keeps moving forward and finishes successfully.
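
For the record, a variant I have not tested (so this is only a sketch, not something from my actual runs, and the helper name load_targets_numpy is made up for it) would be to return plain NumPy arrays from the workers and only build the torch tensors in the parent, so that no torch tensor ever travels back through the pool:

import cooler
import numpy as np
import torch
from multiprocessing import Pool

cool_dir = '/mnt/d/cools/'
test_samples = ["K562_MboI", "endoC", "AoTCPCs", "Liver"]
test_regions = [('chr12', 93236000, 95236000), ('chr2', 219504000, 221504000)]
num_processes = 1

def load_targets_numpy(args):
    # Same worker as above, but the result stays a NumPy array.
    sample_name, cool_dir, sampled_regions = args
    cooler_file = cooler.Cooler(f"{cool_dir}/{sample_name}.sumnorm.mcool::/")
    matrices = [cooler_file.matrix(balance=False).fetch(region) for region in sampled_regions]
    return np.stack(matrices)

if __name__ == "__main__":
    with Pool(num_processes) as pool:
        arrays = pool.map(load_targets_numpy, [(s, cool_dir, test_regions) for s in test_samples])
    # Convert to torch only in the parent process.
    targets = torch.tensor(np.concatenate(arrays))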

I've tried several ways to figure it out, including deleting the variables after using them, closing the file handle of the cool file (although it is actually closed automatically), putting the code into an if __name__ == "__main__": block, changing the multiprocessing start method from fork to spawn, using a multiprocessing lock, and cloning the tensors in targets before manipulating them (a sketch of two of these attempts follows below)... But none of these is the point. The only thing I know is that manipulating the results blocks (at least one) new child process from fetching data from the cool file. As pdb cannot be used in a child process, I don't know how to debug further.
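
Roughly, those two attempts looked like this (reusing load_targets and the variables from the reproducer above); as said, neither resolved the hang:

import multiprocessing as mp

# Wrap the driver in a __main__ guard and force the "spawn" start method
# instead of "fork", so workers do not inherit state from the parent.
if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    with ctx.Pool(num_processes) as pool:
        targets = pool.map(load_targets, [(sample_name, cool_dir, test_regions) for sample_name in test_samples])
    targets = torch.cat(targets)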

Luckily, I found a workaround at last. Instead of using multiprocessing or torch.multiprocessing, from pathos.multiprocessing import ProcessingPool as Pool prevents the problem. However, exploring the cause of this phenomenon is beyond my ability. I just hope this will be helpful for those trying to process cool files with multiprocessing.
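
Concretely, the workaround is the commented-out import in the reproducer above; a minimal sketch, reusing load_targets and the variables defined there:

from pathos.multiprocessing import ProcessingPool as Pool

# pathos serializes with dill rather than the standard pickle module.
pool = Pool(num_processes)
try:
    targets = pool.map(load_targets, [(sample_name, cool_dir, test_regions) for sample_name in test_samples])
finally:
    pool.close()
    pool.join()
targets = torch.cat(targets)
print("Shape!!!", targets.shape)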
