Commit 762e70ed authored by David M. Rogers's avatar David M. Rogers
Browse files

Extracted out the per-ligand function.

parent 98c0230d
Loading
Loading
Loading
Loading

add_pdbqt_signac_mpi2.py

deleted100644 → 0
+0 −96
Original line number Diff line number Diff line
#!/usr/bin/env python3
# Parse files like ligands/PV-001952970482_1_T2.pdbqt
# present in input files like /gpfs/alpine/bif128/world-shared/runs/N108/A_C_B_Cl_I_NA_F_OA_N_S_Br_SA_HD_13types_output_p35245.tar.gz
# or files like ligands/Z1149212469_1_T1.pdbqt
# present in input files like /gpfs/alpine/bif128/world-shared/ligand_shards/A_C_F_NA_Cl_OA_N_S_Br_SA_HD_11types_output_p61791.tar.gz

import sys, re
import logging

from helpers import *
import signac

from mpi4py import MPI
from mpi4py.futures import MPICommExecutor

project = signac.get_project()

def insert_signac(lname, lshard, latoms, ltors):
    #job = project.open_job(statepoint={'real_id':name})
    #ljob = project.add_many( {'real_id':name} for name in lname )
    ljob = [ project.open_job(statepoint={'real_id':name}) for name in lname ]
    added_jobs = project.add_many(ljob)

    for job, shard, atoms, tors in zip(ljob, lshard, latoms, ltors):
        job.doc['atoms'] = atoms
        job.doc['tors']  = tors
        job.doc['shard'] = shard

    return True

def insert_loop(tname):
      with signac.pymongo_buffered(project):
        try:
            shard = shard_id(tname)
        except ValueError:
            print("Bad name encountered, {}".format(tname))
            return ("","","")
        fname = None
        k = 0
        lname, lshard, latoms, ltors = [], [], [], []
        for fname, name, f in tar_iter(tname, "pdbqt"):
            try:
                atoms, tors = get_pdbqt_info(f)
                if atoms == 0 or tors is None:
                    print("Bad molecule %s %s %d %s"%(fname,name,atoms,str(tors)))
                    if tors is None:
                        tors = 0
            except:
                print("Bad file: {}".format(fname))
                continue
            lname.append(name)
            lshard.append(shard)
            latoms.append(atoms)
            ltors.append(tors)
            k += 1
            if k%100 == 0:
                insert_signac(lname, lshard, latoms, ltors)
                lname, lshard, latoms, ltors = [], [], [], []
        if fname is None:
            print("No pdbqt files found.")
            return ("", "", "")
        if len(lname) > 0:
            insert_signac(lname, lshard, latoms, ltors)

        return (shard, fname, k)

def main(shards):
    logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
    with MPICommExecutor(MPI.COMM_WORLD, root=0) as executor:
      if executor is not None:
        return list(executor.map(insert_loop, shards))
    #for tname in mpi_tname:
    #  insert_loop(tname)

def get_pdbqt_info(f):
    atom = 0
    tors = None
    for line in f:
        try:
            if line[:6] == b"ATOM  " or line[:6] == b"HETATM":
                atom = max(atom, int(line[6:11]))
            elif line[:7] == b"TORSDOF":
                tors = int(line[7:])
        except ValueError:
            print("Error getting info. `{}`".format(line))
            raise
    return atom, tors

if __name__ == "__main__":
    print("Reading file file")
    with open("/lustre/or-hydra/cades-bsd/world-shared/bzf-ligand/mongo/ligand_shards.txt", "r") as shard_file:
        shard_data = shard_file.read()
    print("Processing")
    res = main(shard_data.strip().split("\n")[int(sys.argv[1]):])
    #for x in res:
    #    print(x)
+6 −24
Original line number Diff line number Diff line
import re
import tarfile as tf
from pathlib import Path

# FIXME: parse sys.env['SLURM_NODELIST'] = "rhea[473,509]"
def get_rdb(name=None):
    if name is None:
        host = 'localhost'
    else:
        host = open("/gpfs/alpine/proj-shared/bif128/redis/%s.server"%name).read().strip()
    import redis
    r = redis.Redis(host=host, port=6379, password="Z1908840168_2_T1", db=0)
    return r

def shard_id(fname):
    s = fname.rindex("output_")
    return fname[s+7:].split('.')[0]

# iterator over tar-file
def tar_iter(name, ext=None):
    inp = Path(name)
@@ -37,17 +24,12 @@ def tar_iter(name, ext=None):
# parse the first cluster from each included xml file.
# example cluster syntax is:
#   <cluster cluster_rank="1" lowest_binding_energy="-4.09" run="9" mean_binding_energy="-4.09" num_in_clus="1" />
import re
cl = re.compile(b'\s*<cluster cluster_rank="1" lowest_binding_energy="([^"]*)" run="([^"]*)" mean_binding_energy="([^"]*)" num_in_clus="([^"]*)" />')
def xml_to_energy(f):
    for line in f:
        m = cl.match(line)
        if m:
            try:
            return float(m[1]),int(m[2]),float(m[3]),int(m[4])
            except ValueError:
                print("error extracting energy from xml. {}".format(line))
                raise
    return None

def grep_all(f, *keys):
@@ -70,18 +52,18 @@ def parse_dlg(f):
# inputstrings are all containing "DOCKED:"
def dlg_to_confs(lines):
    confs = []
    conf = ""
    conf = []
    en = None
    for s in lines:
        conf += s[8:] + "\n"
        conf.append( s[8:].strip() )
        # DOCKED: USER    Estimated Free Energy of Binding    =  -7.85 kcal/mol
        if "Estimated Free Energy of Binding" in s:
            tok = s.replace('=','').split()
            en = float(tok[7])
        elif s[8:14] == "ENDMDL":
            confs.append( (en, conf) )
            assert en is not None
            conf = ""
            confs.append( (en, conf) )
            conf = []
            en = None
    confs.sort()
    return confs

map_fn.py

0 → 100755
+47 −0
Original line number Diff line number Diff line
#!/usr/bin/env python3

from run import extension, map_fn

import sys, re
import logging

from helpers import *
import signac

from mpi4py import MPI
from mpi4py.futures import MPICommExecutor

project = signac.get_project()

# relies on global variables extension and map_fn (see example run.py)
def insert_loop(tname):
    with signac.pymongo_buffered(project):
        fname = None
        for fname, name, f in tar_iter(tname, extension):
            job = project.open_job(statepoint = {'real_id':name})
            try:
                map_fn(job, f)
            except:
                print("Bad file: {}".format(fname))
                continue
        if fname is None:
            print("No %s files found in %s."%(extension, tname))

def main(argv):
    logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)

    assert len(argv) == 2, "Usage: %s <list of input tgz files>"
    with open(argv[1]) as shard_file:
        shard_data = shard_file.read()
    shards = shard_data.strip().split("\n")

    with MPICommExecutor(MPI.COMM_WORLD, root=0) as executor:
      if executor is not None:
        ret = list(executor.map(insert_loop, shards))

    return ret

if __name__ == "__main__":
    import sys
    main(sys.argv)

run.py

0 → 100644
+43 −0
Original line number Diff line number Diff line
from helpers import *
import numpy as np

receptor = '6WQF'
error_key = '%s_error'%receptor
score_key = '%s_score'%receptor
conf_key  = '%s_conf' %receptor
ctr_key   = '%s_ctr'  %receptor

map = lambda f, x: [f(y) for y in x]

def add_score(job, f):
    en, conf = parse_dlg(f)
    if en > 100.0 or np.isnan(en) or np.isinf(en):
        job.doc[error_key] = 0
        return

    try:
       # conf contains lines like
       # ATOM      2  C   LIG     1       7.314  -0.197  19.453
       xyz = [ map(float, [line[30:38], line[38:46], line[46:54]]) \
                for line in conf if line[:6] == "ATOM  " or line[:6] == "HETATM" ]
       x = np.array(xyz)
       l1 = np.abs( x[:,newaxis] - x ).max(2)
       l1 += np.identity(len(l1))
       if l1.min() < 0.05:
           job.doc['AD_error'] = 1
           return
    except ValueError: # couldn't parse coords.
        job.doc[error_key] = 2
        return

    job.doc[score_key] = en
    job.doc[conf_key] = "\n".join(conf)
    job.doc[ctr_key]  = np.sum(x,0)/len(x)

# extension : xvg, pdbqt
# map_fn : (job, file obj) -> ()
#     -- can run, e.g. job.doc[attr] = val
# the function to map
extension = "dlg"
map_fn = add_score

run.sh

0 → 100644
+16 −0
Original line number Diff line number Diff line
#!/bin/bash
#SBATCH -A bif128
#SBATCH -t 20
#SBATCH -N 1
#SBATCH -J map_fn
#SBATCH -o map_fn.%J

. /ccs/proj/bif128/venvs/env.sh
nodes=$SLURM_JOB_NUM_NODES
DIR=/ccs/proj/bif128/analysis/add_signac_mongo

time srun -n $((16*nodes)) -N $nodes --cpu-bind=cores $DIR/map_fn.py test.list

# when working, do:
#srun -n $((16*nodes)) -N $nodes --cpu-bind=cores $DIR/map_fn.py /gpfs/alpine/world-shared/bif128/6WQF_docked/list
Loading