Skip to content

Commit

Permalink
Merge pull request #1 from domonik/main
Browse files Browse the repository at this point in the history
Bug fixes and multicore support for generate_kmer_features.py
  • Loading branch information
martin-raden authored Nov 23, 2021
2 parents 98cef8b + a0a014c commit 22ce611
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 29 deletions.
13 changes: 13 additions & 0 deletions conda-environment.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
name: BrainDead
channels:
- conda-forge
- defaults
- bioconda
dependencies:
- viennarna=2.4.18
- intarna
- biopython
- pandas
- scikit-learn

147 changes: 118 additions & 29 deletions src/generate_kmer_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,17 @@
import re
import argparse
import os.path
from multiprocessing import Pool
from tempfile import TemporaryDirectory
import pickle
from typing import List
import sys
BINDIR = os.path.dirname(os.path.realpath(__file__))
def find_kmer_hits(sequence, kmer):
    """Return the 0-based start positions of every (overlapping) occurrence of *kmer* in *sequence*.

    A look-ahead pattern ``(?=...)`` is used so that overlapping matches are
    reported (plain ``re.finditer`` would skip past each match). The kmer is
    escaped with ``re.escape`` so it is always matched literally — for plain
    nucleotide kmers this is identical to the unescaped behavior, but it
    prevents crashes/mis-matches if a kmer ever contains a regex
    metacharacter.

    :param sequence: string to search in (e.g. an RNA/DNA sequence)
    :param kmer: substring to locate
    :return: list of int start indices, in ascending order
    """
    return [m.start() for m in re.finditer('(?=' + re.escape(kmer) + ')', sequence)]

def call_command (cmd):
p = subprocess.Popen(cmd,shell=True,stdin=None, stdout=PIPE)
p = subprocess.Popen(cmd,shell=True,stdin=None, stdout=PIPE, stderr=PIPE)
(result, error) = p.communicate()
if error:
raise RuntimeError("Error in calling cmd or perl script\ncmd:{}\nstdout:{}\nstderr:{}".format(cmd, result, error))
Expand Down Expand Up @@ -63,6 +67,50 @@ def is_valid_file(file_name):
return os.path.abspath(file_name)
else:
raise FileNotFoundError(os.path.abspath(file_name))

def multicore_wrapper(seq_record, args):
    """Build the feature-CSV row for a single FASTA record (Pool worker entry point).

    Intended to be called via ``multiprocessing.Pool.starmap``. Relies on the
    module-level ``kmers_list`` (set in ``__main__``) and the sibling helpers
    ``get_subopt_intarna_strs`` and ``find_hits_all``.

    :param seq_record: a Bio.SeqIO record providing ``.id`` and ``.seq``
    :param args: parsed argparse namespace; reads ``minE_subopt``,
        ``minE_intarna``, ``feature_context`` and ``report_counts``
    :return: one CSV line (no trailing newline) starting with the record id,
        followed by one count (or 0/1 flag) column per selected context
        per kmer in ``kmers_list``
    """
    out_csv_str = seq_record.id
    print(seq_record.id)  # progress output; interleaves when >1 worker process

    sequence = str(seq_record.seq)
    # Hoisted out of the kmer loop: invariant across all iterations.
    context = args.feature_context.lower()

    seq_subopt, seq_intarna = get_subopt_intarna_strs(sequence,
                                                      minE_subopt=args.minE_subopt,
                                                      minE_intarna=args.minE_intarna)
    for kmer in kmers_list:
        hseq, hsubopt, hintarna, hsubopt_intarna = find_hits_all(sequence,
                                                                 seq_subopt,
                                                                 seq_intarna,
                                                                 kmer)
        # Select the requested feature contexts (a=all, s=subopt,
        # h=hybrid/IntaRNA, u=subopt+IntaRNA), keeping this fixed order.
        array_features = []
        if "a" in context:
            array_features.append(len(hseq))
        if "s" in context:
            array_features.append(len(hsubopt))
        if "h" in context:
            array_features.append(len(hintarna))
        if "u" in context:
            array_features.append(len(hsubopt_intarna))

        if args.report_counts is True:
            # raw hit counts
            out_csv_str += ''.join([',{}'.format(f) for f in array_features])
        else:
            # binary presence flags: 0 = no hit, 1 = at least one hit
            binary_hits = ['0' if c == 0 else '1' for c in array_features]
            out_csv_str += "," + ','.join(binary_hits)
    return out_csv_str


def write_pickled_output(files: List[str], outfile: str, csv_header: str):
    """Concatenate pickled batches of CSV rows into one CSV file.

    Every path in *files* holds a pickled list of CSV row strings (one per
    sequence). The rows are written to *outfile* in the given file order,
    preceded by *csv_header* (which must already include its trailing
    newline); each batch is followed by a newline.

    :param files: ordered list of paths to pickled row-list batches
    :param outfile: destination CSV path (overwritten)
    :param csv_header: header text written verbatim before the first batch
    """
    with open(outfile, "w") as sink:
        sink.write(csv_header)
        for batch_path in files:
            with open(batch_path, "rb") as batch_file:
                rows = pickle.load(batch_file)
                sink.write("\n".join(rows) + "\n")
            # drop the batch before loading the next one to cap memory use
            del rows




if __name__ == '__main__':

Expand All @@ -72,6 +120,8 @@ def is_valid_file(file_name):

parser.add_argument('--kmers', required=True, type=str, help='List of kmers as a comma separated string e.g. \"AGG,GA,GG\"')
parser.add_argument('--fasta', required=True, type=is_valid_file, help='Sequences to extract features from as a FASTA file')
parser.add_argument('--threads', type=int, default=1, help='Number of threads used for processing (default: 1) (WARNING: threads > 1 will impair stdout prints')
parser.add_argument('--batchsize', type=int, default=10000, help='If the number of processed fasta sequences is greater than batch size batch processing will be applied. This will lower memory consumption (default: 10000)')
parser.add_argument('--report-counts', action='store_true', help='Whether to report counts as integer, default is binary nohit(0)-hit(1)'),
parser.add_argument('--out-csv', type=str, default='stdout', help='CSV File name to write counts, pass "stdout" for stdout ')
parser.add_argument('--minE-subopt', default=-3, type=int, help='Minimum free energy of the position on RNAsubopt result')
Expand All @@ -97,33 +147,72 @@ def is_valid_file(file_name):
out_csv_str += ",{}_free".format(kmer)

out_csv_str += '\n'
for r in SeqIO.parse(args.fasta, format='fasta'):
print(r.id)
out_csv_str += r.id
seq_subopt, seq_intarna = get_subopt_intarna_strs(str(r.seq), minE_subopt=args.minE_subopt, minE_intarna=args.minE_intarna)
for kmer in kmers_list:
hseq, hsubopt, hintarna, hsubopt_intarna = find_hits_all(str(r.seq),seq_subopt,seq_intarna, kmer)
print(kmer, hseq, hsubopt, hintarna, hsubopt_intarna)
cseq, csubopt, cintarna, csubopt_intarna = len(hseq), len(hsubopt), len(hintarna), len(hsubopt_intarna)
array_features = []
if "a" in args.feature_context.lower():
array_features.append(cseq)
if "s" in args.feature_context.lower():
array_features.append(csubopt)
if "h" in args.feature_context.lower():
array_features.append(cintarna)
if "u" in args.feature_context.lower():
array_features.append(csubopt_intarna)

if args.report_counts is True:
out_csv_str += ''.join([',{}'.format(f) for f in array_features])
else:
binary_hits = ['0' if c==0 else '1' for c in array_features]
out_csv_str += ","+','.join(binary_hits)
out_csv_str += '\n'
if args.threads == 1:
for r in SeqIO.parse(args.fasta, format='fasta'):
print(r.id)
out_csv_str += r.id
seq_subopt, seq_intarna = get_subopt_intarna_strs(str(r.seq), minE_subopt=args.minE_subopt, minE_intarna=args.minE_intarna)
for kmer in kmers_list:
hseq, hsubopt, hintarna, hsubopt_intarna = find_hits_all(str(r.seq),seq_subopt,seq_intarna, kmer)
print(kmer, hseq, hsubopt, hintarna, hsubopt_intarna)
cseq, csubopt, cintarna, csubopt_intarna = len(hseq), len(hsubopt), len(hintarna), len(hsubopt_intarna)
array_features = []
if "a" in args.feature_context.lower():
array_features.append(cseq)
if "s" in args.feature_context.lower():
array_features.append(csubopt)
if "h" in args.feature_context.lower():
array_features.append(cintarna)
if "u" in args.feature_context.lower():
array_features.append(csubopt_intarna)

if args.out_csv == "stdout":
print(out_csv_str)
if args.report_counts is True:
out_csv_str += ''.join([',{}'.format(f) for f in array_features])
else:
binary_hits = ['0' if c==0 else '1' for c in array_features]
out_csv_str += ","+','.join(binary_hits)
out_csv_str += '\n'

if args.out_csv == "stdout":
print(out_csv_str)
else:
with open(args.out_csv, 'w') as outfile:
outfile.write(out_csv_str)
else:
with open(args.out_csv, 'w') as outfile:
outfile.write(out_csv_str)

calls = []
for seq_record in SeqIO.parse(args.fasta, format='fasta'):
calls.append((seq_record, args))

if args.batchsize < len(calls):
tmp_dir = TemporaryDirectory(prefix="BrainDead")
files = []
batch_calls = [calls[x:x+args.batchsize] for x in range(0, len(calls), args.batchsize)]
for x, batch in enumerate(batch_calls):
with Pool(processes=args.threads) as pool:
outstrings = pool.starmap(multicore_wrapper, batch)
file = os.path.join(tmp_dir.name, f"batch_{x}.pckl")
files.append(file)
with open(file, "wb") as handle:
pickle.dump(outstrings, handle)
write_pickled_output(files=files,
outfile=args.out_csv,
csv_header=out_csv_str)


else:
with Pool(processes=args.threads) as pool:
outstrings = pool.starmap(multicore_wrapper, calls)

out_csv_str += "\n".join(outstrings) + "\n"

if args.out_csv == "stdout":
print(out_csv_str)
else:
with open(args.out_csv, 'w') as outfile:
outfile.write(out_csv_str)





0 comments on commit 22ce611

Please sign in to comment.