Commit 3c13feec authored by David M. Rogers

Multiprocessing rescoring module.

parent 5f593c56
@@ -4,7 +4,7 @@
#SBATCH --cpus-per-task 2
#SBATCH --gres gpu:1
#SBATCH -J dock
#SBATCH -o %x.%A_%a.out
#SBATCH -o %x.%A_%a.%j.out
#SBATCH --array=1-866
# TODO: add date/time to output filename
......
@@ -10,7 +10,7 @@ rules = yaml.safe_load(open(base / 'rules.yaml'))
bucket = 'gs://ccddc'
test = False
hopper = True
hopper = False
conn_retries = 0
@@ -111,7 +111,8 @@ def requeue(assigned, host, db):
if item is None:
break
r.sadd('ready', item)
#r.smove(assigned, 'ready', item)
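        # 'hopper' is a Redis counter that (apparently) tracks how many
        # items have been returned to the 'ready' set for re-processing.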
if hopper:
r.incr('hopper')
print("%s %s re-queued %s."%(stamp(), assigned, item))
else:
raise IndexError("More than 10 items assigned to %s!"%assigned)
......
#!/usr/bin/env python3
import time
from multiprocessing import Process, Queue, Event
from queue import Empty
class WorkQueue:
def __init__(self, end, producers=0):
self.q = Queue()
self.end = end
self.producers = producers
self.done = 0
def __iter__(self):
return self
def __next__(self):
        # Solves the signaling problem of producers alerting
        # consumers by having each consumer count the number of
        # "completed" producers.
        # Each completed producer adds a 'None' sentinel to the queue,
        # and a consumer that has seen all sentinels re-posts them so
        # its sibling consumers also stop.
while not self.end.is_set():
try:
x = self.q.get(timeout=30)
except Empty:
continue
if x is None:
self.done += 1
if self.done == self.producers:
break
continue
            while self.done > 0: # put all sentinels back
self.q.put(None)
self.done -= 1
return x
for i in range(self.producers):
self.q.put(None)
raise StopIteration
# Three methods called by producers
def register(self):
self.producers += 1
def put(self, x):
self.q.put(x)
def fin(self):
self.q.put(None)
def summary(times):
if len(times) == 0:
return (0,0,0)
return min(times), max(times), sum(times)/len(times)
class SignalObject:
MAX_TERMINATE_CALLED = 3
def __init__(self, shutdown_event):
self.terminate_called = 0
self.shutdown_event = shutdown_event
import functools, signal
def default_signal_handler(
signal_object,
signal_num,
current_stack_frame):
signal_object.terminate_called += 1
signal_object.shutdown_event.set()
if signal_object.terminate_called == signal_object.MAX_TERMINATE_CALLED:
raise Exception("Program terminated via signal.")
def init_signal(signal_num, signal_object, handler):
handler = functools.partial(handler, signal_object)
signal.signal(signal_num, handler)
signal.siginterrupt(signal_num, False)
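    # flag=False: system calls interrupted by this signal are restarted
    # rather than raising an error, so blocking queue reads survive it.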
def init_signals(shutdown_event, int_handler=default_signal_handler, term_handler=default_signal_handler):
signal_object = SignalObject(shutdown_event)
init_signal(signal.SIGINT, signal_object, int_handler)
init_signal(signal.SIGTERM, signal_object, term_handler)
return signal_object
class Worker(Process):
""" Meant to be sub-classed so that you can define
your own setup() and fn() methods
"""
def __init__(self, inp, out):
Process.__init__(self)
self.inp = inp
self.out = out
self.out.register()
def setup(self):
pass
def run(self):
times = []
init_signals(self.inp.end)
self.setup()
for x in self.inp:
t0 = time.time()
try:
y = self.fn(x)
except Exception as e:
self.inp.end.set() # queue Jim Morrison
raise e
times.append(time.time() - t0)
self.out.put(y)
#print("%s: %.3f %.3f %.3f" % ((type(self).__name__,) + summary(times)))
self.out.fin()
print("%s: %.3f %.3f %.3f" % ((type(self).__name__,) + summary(times)))
# 1. parse to mol
# and calculate descriptors
# 2. RF-3 score
# 3. RF-VS 2
class Incrementor(Worker):
def fn(self, x):
time.sleep(0.1)
return x+1
def main():
end = Event()
inp = WorkQueue(end, 1)
out = WorkQueue(end)
worker = Incrementor(inp, out)
worker2 = Incrementor(inp, out)
# out now has 2 registered producers
worker.start()
worker2.start()
for i in range(10):
inp.put(i)
inp.fin()
for ret in out:
print(ret)
worker.join()
worker2.join()
if __name__ == '__main__':
main()
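# The demo above prints the integers 1..10 in nondeterministic order (two
# Incrementor processes drain the same input queue), followed by each
# worker's min/max/mean per-item timing line from summary().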
@@ -5,18 +5,17 @@ import os, concurrent, subprocess
import pandas as pd
import numpy as np
from q2 import Event, Worker, WorkQueue, time
import oddt
from oddt.scoring import descriptors
from oddt.scoring.functions import RFScore
from oddt.scoring.models.regressors import randomforest
def fhash(x):
return (48271*x)%2147483647
def ihash(y):
return (1899818559*y)%2147483647
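# fhash is the MINSTD multiplicative hash (x -> 48271*x mod the Mersenne
# prime 2**31-1); 1899818559 appears to be the modular inverse of 48271
# mod 2147483647, so ihash(fhash(x)) == x for 0 < x < 2**31-1.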
threads = 1
batch_sz = 16
threads = 33
batch_sz = 648
def gsutil(cmd):
args = ["gsutil", "-o", "GSUtil:parallel_process_count=1"
@@ -28,109 +27,174 @@ def gsutil(cmd):
def process_inp(r, name):
n = ihash( int(name, 16) )
inp = "10344a.pq 11ad68.pq 132686.pq 16d551.pq d420e.pq 10f0d9.pq 1269f7.pq 1618c2.pq 1791e0.pq dfe9d.pq".split()
#inp = [ "%x.pq" % fhash(n+i) for i in range(batch_sz) ]
inp2 = [ "gs://ccddc/%s_docked/"%r + i for i in inp ]
inp = [ "%x.pq" % fhash(n+i) for i in range(batch_sz) ]
inp2 = [ "gs://ccddc/%s_docked/"%r + i for i in inp ]
gsutil(['cp'] + inp2 + ['./'])
#with concurrent.futures.ProcessPoolExecutor() as executor:
with concurrent.futures.ThreadPoolExecutor(max_workers=threads) as executor:
ans = executor.map(rescore, inp)
return pd.concat(ans)
receptor = None
end = Event()
start = WorkQueue(end, 1)
out1 = WorkQueue(end)
out2 = WorkQueue(end)
done = WorkQueue(end)
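    # Pipeline of worker processes:
    #   start -> [LoadMol x n_loaders] -> out1 -> rf3 -> out2 -> dude2 -> done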
n_loaders = threads - 2
loaders = []
for i in range(n_loaders):
loaders.append( LoadMol(start, out1) )
loaders[-1].r = r
loaders[-1].start()
rf3 = Scorer(out1, out2)
rf3.name = "rf3"
rf3.model = "/apps/data/RFScore_v3_pdbbind2016.pickle"
rf3.version = 3
rf3.start()
dude2 = Scorer(out2, done)
dude2.name = "vs_dude_v2"
dude2.model = "/apps/data/RFScoreVS_v2_dude.pickle"
dude2.version = 2
dude2.start()
for i in inp:
start.put(i)
start.fin()
ans = [ df for df in done ]
if len(ans) > 0:
ans = pd.concat(ans)
else:
ans = pd.DataFrame()
ans.to_parquet(name+'.pq', compression='snappy', engine='pyarrow')
end.set()
return stop_procs(loaders + [rf3, dude2])
def main(argv):
global receptor
global threads
global batch_sz
if len(argv) >= 3 and argv[1] == "-n":
batch_sz = int(argv[2])
threads = batch_sz+2
del argv[1:3]
    assert len(argv) == 3, "Usage: %s <receptor id> <lig id>" % argv[0]
# set up descriptors
receptor = next(oddt.toolkit.readfile('pdbqt', argv[1]+'.pdbqt'))
result = process_inp(argv[1], argv[2])
result.to_parquet(argv[2]+'.pq',
compression='snappy', engine='pyarrow')
def get_descriptors(receptor, confs):
cutoff = 12
    ligand_atomic_nums = [6, 7, 8, 9, 15, 16, 17, 35, 53]  # C, N, O, F, P, S, Cl, Br, I
    protein_atomic_nums = [6, 7, 8, 16]  # C, N, O, S
cc = oddt.scoring.descriptors.close_contacts_descriptor(
receptor,
cutoff=cutoff,
protein_types=protein_atomic_nums,
ligand_types=ligand_atomic_nums)
vina_scores = ['vina_gauss1',
'vina_gauss2',
'vina_repulsion',
'vina_hydrophobic',
'vina_hydrogen',
'vina_num_rotors']
vina = oddt.scoring.descriptors.oddt_vina_descriptor(receptor, vina_scores=vina_scores)
#descriptors_v1 = cc
#descriptors_v2 = oddt.scoring.descriptors.close_contacts_descriptor(
# receptor,
# cutoff=np.array([0, 2, 4, 6, 8, 10, 12]),
# protein_types=protein_atomic_nums,
# ligand_types=ligand_atomic_nums)
    descriptors_v3 = oddt.scoring.ensemble_descriptor((vina, cc))  # RF-Score v3 features: Vina terms + close contacts
return [ descriptors_v3.build( oddt.toolkit.readstring('pdbqt', x) ).reshape(-1)
for x in confs ]
# load models
models = [
('rf3', '/apps/data/RFScore_v3_pdbbind2016.pickle' )
, ('dude3', '/apps/data/RFScoreVS_v3_dude.pickle' )
, ('dock3', '/apps/data/RFScoreVS_v3_dock.pickle' )
, ('vina3', '/apps/data/RFScoreVS_v3_vina.pickle' )
]
# parallel load all these pickles
#with concurrent.futures.ProcessPoolExecutor(max_workers=threads) as executor:
with concurrent.futures.ThreadPoolExecutor(max_workers=threads) as executor:
models = dict( executor.map(lambda m: (m[0], RFScore.rfscore.load(m[1], version=3)),
models) )
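# A ThreadPoolExecutor is used here because the lambda above is not
# picklable, so it could not be shipped to ProcessPoolExecutor workers;
# threads share the interpreter and need no pickling.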
#models = dict(
# rf3 = RFScore.rfscore.load(
# '/apps/data/RFScore_v3_pdbbind2016.pickle', version=3)
## , vs_dude_v1 = RFScore.rfscore.load(
## 'RFScoreVS_v1_dude.pickle',version=1)
## , vs_dude_v2 = RFScore.rfscore.load(
## 'RFScoreVS_v2_dude.pickle',version=2)
# , vs_dude_v3 = RFScore.rfscore.load(
# '/apps/data/RFScoreVS_v3_dude.pickle',version=3)
## , vs_dock_v1 = RFScore.rfscore.load(
## 'RFScoreVS_v1_dock.pickle',version=1)
## , vs_dock_v2 = RFScore.rfscore.load(
## 'RFScoreVS_v2_dock.pickle',version=2)
# , vs_dock_v3 = RFScore.rfscore.load(
# '/apps/data/RFScoreVS_v3_dock.pickle',version=3)
## , vs_vina_v1 = RFScore.rfscore.load(
## 'RFScoreVS_v1_vina.pickle',version=1)
## , vs_vina_v2 = RFScore.rfscore.load(
## 'RFScoreVS_v2_vina.pickle',version=2)
# , vs_vina_v3 = RFScore.rfscore.load(
# '/apps/data/RFScoreVS_v3_vina.pickle',version=3)
#)
def rescore(inp):
print("Rescoring %s"%inp)
df = pd.read_parquet(inp)
os.remove(inp)
columns = [ 'rf3'
, 'dude3'
, 'dock3'
, 'vina3'
]
dvs = get_descriptors(receptor, df['conf'].values)
data = df['score']
for c in columns:
data[c] = models[c].model.predict(dvs)
return data
status = process_inp(argv[1], argv[2])
print(status)
class LoadMol(Worker):
""" load a molecule and calculate its descriptors """
def setup(self):
t0 = time.time()
from oddt.scoring import descriptors
# set up descriptors
receptor = next(oddt.toolkit.readfile('pdbqt',
self.r+'.pdbqt'))
cutoff = 12
        ligand_atomic_nums = [6, 7, 8, 9, 15, 16, 17, 35, 53]  # C, N, O, F, P, S, Cl, Br, I
        protein_atomic_nums = [6, 7, 8, 16]  # C, N, O, S
self.v2 = descriptors.close_contacts_descriptor(
receptor,
cutoff=np.array([0, 2, 4, 6, 8, 10, 12]),
protein_types=protein_atomic_nums,
ligand_types=ligand_atomic_nums)
cc = descriptors.close_contacts_descriptor(
receptor,
cutoff=cutoff,
protein_types=protein_atomic_nums,
ligand_types=ligand_atomic_nums)
#v1 = cc
vina_scores = ['vina_gauss1',
'vina_gauss2',
'vina_repulsion',
'vina_hydrophobic',
'vina_hydrogen',
'vina_num_rotors']
vina = descriptors.oddt_vina_descriptor(receptor,
vina_scores=vina_scores)
        self.v3 = oddt.scoring.ensemble_descriptor((vina, cc))  # RF-Score v3 features: Vina terms + close contacts
dt = time.time() - t0
print("LoadMol setup done in %f seconds"%dt)
def fn(self, inp):
try:
df = pd.read_parquet(inp)
os.remove(inp)
except FileNotFoundError:
print("Error: Input file %s is missing!"%inp)
return pd.DataFrame()
v2 = self.v2
v3 = self.v3
for x in ['', '2', '3']:
confs = df['conf'+x]
mols = []
for c in confs:
try:
m = oddt.toolkit.readstring('pdbqt', c)
except Exception:
m = None
mols.append(m)
#mols = [ oddt.toolkit.readstring('pdbqt', c) for c in confs ]
#df['vs_dude_v2'+x] = list(v2.build( mols ))
#df['rf3'+x] = list(v3.build( mols ))
df['vs_dude_v2'+x] = [ v2.build( m ).reshape(-1) if m is not None else None for m in mols ]
df['rf3'+x] = [ v3.build( m ).reshape(-1) if m is not None else None for m in mols ]
#if 'Z1509820766_1_T1' in df['name']:
# print( mols[:10] )
# print( df.head() )
return df.drop(columns=['conf', 'conf2', 'conf3'])
class Scorer(Worker):
def setup(self):
t0 = time.time()
from oddt.scoring.functions import RFScore
rfs = RFScore.rfscore.load(self.model, version=self.version)
self.score = rfs.model.predict
dt = time.time() - t0
print("Completed setup of %s in %.3f seconds"%(self.name,dt))
def fn(self, df):
if len(df) == 0:
return df
for x in ['', '2', '3']:
c = self.name + x
#df[c] = self.score( list(df[c].values) )
v = list( df[c].dropna() )
if len(v) == 0:
print("WARNING: Detected empty ligand file!")
                df.loc[df[c].notna(), c] = []  # no-op: the mask selects no rows
else:
df.loc[df[c].notna(), c] = self.score(v)
return df
def stop_procs(procs):
end_time = time.time() + 200 # seconds (be sure they're done)
num_terminated = 0
num_failed = 0
for proc in procs:
join_secs = max(0.01, end_time - time.time())
proc.join(join_secs)
# terminate any procs that still have not exited.
for proc in procs:
if proc.is_alive():
proc.terminate()
num_terminated += 1
else:
exitcode = proc.exitcode
if exitcode:
num_failed += 1
return "%d tasks complete: %d failed, %d terminated"%(len(procs),num_failed, num_terminated)
if __name__=="__main__":
import sys
......
#!/bin/bash
#SBATCH -p rescore
#SBATCH --nodes 1
#SBATCH -n64
#SBATCH -c 64
#SBATCH -J rescore
#SBATCH -o %x.%A_%a.out
#SBATCH --array=1-1
echo "Starting $SLURM_JOB_NAME-$SLURM_ARRAY_TASK_ID at" `date`
source /apps/dock_env/env.sh
export OMP_NUM_THREADS=1
#export OMP_NUM_THREADS=1
#source /apps/dock_env/env.sh
eval "$(/apps/anaconda3/bin/conda shell.bash hook)"
conda activate rescore
conda activate rescore2
DIR=/apps/launchad
cd /dev/shm
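# /dev/shm is a RAM-backed tmpfs; working there keeps the gsutil-staged
# parquet files and intermediate outputs off the node's disk.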
......
@@ -21,20 +21,8 @@ dock:
rm -f *.xml
rm -f *.dlg
# uses output of rescore
# 100k ligands per file
# yields roughly 10k files
combine: # combine 10:1 again ~
[]
# Re-score ligand/receptor conf.
# uses output of combine
# ? rescore all 3 conf?
# - remove "far" ligands
# - combine "close" ligands
# Note: this re-combines files 10:1
# creating output files that span a sequence
# 10k ligands per file, 100k files
# Note: this re-combines files 648:1
rescore:
queue: rescore
db: 1
......