Commit 53c700ab authored by David M. Rogers's avatar David M. Rogers
Browse files

Improved semantics of job variable inheritance.

parent dae85ac6
Loading
Loading
Loading
Loading
+5 −2
Original line number Diff line number Diff line
@@ -83,8 +83,11 @@ def dict_merge(dct, merge_dct):
        if k in dct and isinstance(dct[k], dict) and isinstance(v, dict):
            dict_merge(dct[k], v)
        else:
            dct[k] = merge_dct[k]

            if isinstance(v, dict): # deep-copy
                dct[k] = {}
                dict_merge(dct[k], v)
            else:
                dct[k] = v

# Check for an MPI error in this file
def grep(fname, exp):
+18 −8
Original line number Diff line number Diff line
@@ -43,10 +43,14 @@ class Simple:
        # Fill-in from kws and type-check class vars
        for var, cls in self.reqs.items():
            if var in kws: # initialize from yaml data
                val = kws[var].copy()
                val = kws[var]
                if isinstance(val, dict) and hasattr(self, var) \
                         and isinstance(getattr(self,var), dict): # merge together
                   dict_merge( getattr(self,var), val ) # recursively merge parameter hierarchies
                else:
                    if isinstance(val, dict): # deep-copy
                        setattr(self, var, {})
                        dict_merge( getattr(self,var), val )
                    else:
                        setattr(self, var, val)
            elif hasattr(self, var):
@@ -98,7 +102,7 @@ class Simple:
            f.write(self.format(jobscript % self.script, self.next_job.jobid))
        return False

    def run(self):
    def run(self, testonly=False):
        """ Check on current state and submit next job if necessary.
        """
        state = self.get_state()
@@ -111,6 +115,8 @@ class Simple:
            print(self.next_job)
            return "started"

        if testonly:
            return "none"
        if self.setup_job(): # error creating jobscript
            return "error"
        job_n = queue_job(self.next_job.job_in)
@@ -121,7 +127,7 @@ class Simple:
class Local(Simple):
    time = "0:10"
    nodes = 1
    def run(self):
    def run(self, testonly=False):
        state = self.get_state()
        if state is not "none": # no more work
            return state
@@ -132,6 +138,8 @@ class Local(Simple):
            print(self.next_job)
            return "started"

        if testonly:
            return "none"
        if self.setup_job(): # error creating jobscript
            return "error"
        if self.next_job.exec():
@@ -169,7 +177,7 @@ class Restartable(Simple):
            f.write(self.format(jobscript % self.retry_script, self.next_job.jobid))
        return False

    def run(self):
    def run(self, testonly=False):
        """ Check on current state and submit next job if necessary.
        """
        state = self.get_state()
@@ -182,6 +190,8 @@ class Restartable(Simple):
            print(self.next_job)
            return "started"

        if testonly:
            return "none"
        if self.next_job.jobid == 0: # initial run
            mk = self.setup_job
        else:
@@ -217,17 +227,16 @@ class Chain(Simple):
        Simple.__init__(self, proj, jobname, kws)
        #self.out = self.chain[-1][1]['out']

    def run(self):
    def run(self, testonly=False):
        """ Check on current state and submit next job if necessary.
        """
        inp = self.inp
        for i, (job,args) in enumerate(self.chain):
            if 'dirname' not in args:
                args['dirname'] = self.dirname
            print(i,job.jobtype,job.reqs)
            z = job(self.proj, "%s.%d"%(self.jobname,i), args)
            z.inp.update(inp) # connect inputs and outputs
            state = z.run()
            state = z.run(testonly)
            inp = z.out
            if isinstance(state, int):
                return "started"
@@ -248,6 +257,7 @@ def read_jobtypes(filename):

    for name, info in x.items():
        bname = info['jobtype']
        info['jobtype'] = name
        base = types[bname]
        types[name] = type(name, (base,), info)

+15 −5
Original line number Diff line number Diff line
@@ -14,7 +14,10 @@ def subst_chain(jobtype, args, types):
       for i,d in enumerate(args['chain']):
            if isinstance(d, tuple): # already processed.
                continue
            elif isinstance(d, dict):
                assert len(d) == 1
            else:
                print("Invalid arg. to 'chain:'")
            for type1, args1 in d.items():
                pass
            job = subst_chain(type1, args1, types) # traverse sub-chains
@@ -23,12 +26,19 @@ def subst_chain(jobtype, args, types):

def main(argv):
    jtfile = None
    if len(argv) > 2 and argv[1] == "-j":
    testonly = False
    while len(argv) > 1 and argv[1][0] == '-':
      if argv[1] == "-j":
        jtfile = argv[2]
        del argv[1:3]
      elif argv[1] == "-t":
        testonly = True
        del argv[1]
      else:
        raise KeyError("Unknown option: %s"%argv[1])
    types = read_jobtypes(jtfile)

    assert len(argv) == 3, "Usage: %s [-j jobtypes.yaml] <jobs.yaml> <proj id>"%argv[0]
    assert len(argv) == 3, "Usage: %s [-t] [-j jobtypes.yaml] <jobs.yaml> <proj id>"%argv[0]
    with open(argv[1]) as f:
        jobs = yaml.load(f)
    proj = argv[2]
@@ -43,7 +53,7 @@ def main(argv):

        job = subst_chain(jobtype, args, types)
        job = job(proj, jobname, args)
        job.run()
        job.run(testonly)

if __name__=="__main__":
    import sys