Commit 71d23678 authored by Mccaskey, Alex's avatar Mccaskey, Alex
Browse files

setup kernelbuilder to support parameterized gates, hooked up kernel args

parent 4f3e5819
Loading
Loading
Loading
Loading
+45 −17
Original line number Diff line number Diff line
@@ -618,8 +618,8 @@ class KernelBuilder(object):
    If you do not provide a qreg argument to the constructor (py_args_dict) 
    we will assume a single qreg named q.
    """
    def __init__(self, py_args_dict : dict = {}):
        self.kernel_args = py_args_dict 
    def __init__(self,**kwargs):
        self.kernel_args = kwargs['kernel_args'] if 'kernel_args' in kwargs else {}
        # Returns list of tuples, (name, nRequiredBits, isParameterized)
        all_instructions = internal_get_all_instructions()
        self.qjit_str = ''
@@ -631,15 +631,29 @@ class KernelBuilder(object):
            n_bits = instruction[1]
            name = instruction[0]

            # Only consider instructions without parameters for now
            if not instruction[2]:
                # No parameters...
                # set it as a method on this class
            qbits_str = ','.join(['q{}'.format(i) for i in range(n_bits)])
                new_func_str = 'def {}(self, {}):\n'.format(instruction[0].lower(),qbits_str)
                qbit_str = ','.join([])
                new_func_str += self.TAB +"self.qjit_str += '    {}(".format(name)+','.join(
                    ["{}[{{}}]".format(self.qreg_name) for i in range(n_bits)])+")\\n'.format({})".format(qbits_str)
            qbits_indexed = ','.join(["{}[{{}}]".format(self.qreg_name) for i in range(n_bits)])
            new_func_str = '''def {}(self, {}, *args):
    params_str = ''
    params = []
    if len(args):
        for arg in args:
            if isinstance(arg, str):
                params.append(str(arg))
            elif isinstance(arg, tuple):
                params.append(arg[0]+'['+str(arg[1])+']')
            else:
                print('[KernelBuilder Error] Invalid parameter type.')
                exit(1)
        params_str = ','.join(params)
    if {} and len(args) == 0:
        print("[KernelBuilder Error] You are calling a parameterized instruction ({}) but have not provided any parameters")
        exit(1)
    if not params_str:
        self.qjit_str += self.TAB+'{}({})\\n'.format({})
    else:
        self.qjit_str += self.TAB+'{}({}, {{}})\\n'.format({}, params_str)
'''.format(name.lower(), qbits_str, isParameterized, name.lower(), name, qbits_indexed, qbits_str, name, qbits_indexed, qbits_str)
            # print(new_func_str)
            result = {}
            exec (new_func_str.strip(), result)
@@ -649,14 +663,28 @@ class KernelBuilder(object):
        self.qjit_str += self.TAB + 'for i in range({}.size()):\n'.format(self.qreg_name)
        self.qjit_str += self.TAB+self.TAB+'Measure({}[i])\n'.format(self.qreg_name)

    # def measure(self, qbits):

    def create(self):
        # print(self.qjit_str)
        allowed_type_map = {'<class \'_pyqcor.qreg\'>': 'qreg',
                                     '<class \'float\'>': 'float', 'typing.List[float]': 'List[float]',
                                     '<class \'int\'>': 'int', 'typing.List[int]': 'List[int]',
                                     '<class \'_pyxacc.quantum.PauliOperator\'>': 'PauliOperator',
                                     '<class \'_pyxacc.quantum.FermionOperator\'>': 'FermionOperator',
                                     'typing.List[typing.Tuple[int, int]]': 'List[Tuple[int,int]]',
                                     'typing.List[_pyxacc.quantum.PauliOperator]': 'List[PauliOperator]',
                                     'typing.List[_pyxacc.quantum.FermionOperator]': 'List[FermionOperator]'}

        kernel_name = '__internal_qjit_kernelbuilder_kernel_'+str(uuid.uuid4()).replace('-','_')
        if inspect.stack()[-1].code_context is not None:
            kernel_name = inspect.stack()[-1].code_context[0].split(' = ')[0]

        args_str = 'q : qreg, ' + ', '.join(k+' : '+allowed_type_map[str(v)] for k,v in self.kernel_args.items())
        func = 'def {}({}):\n'.format(kernel_name, args_str)+self.qjit_str
        # print(func)
        result = globals()
        exec('def {}(q : qreg):\n'.format(kernel_name)+self.qjit_str, result)
        exec(func, result)
        # print(result)
        function = result[kernel_name]
        _qjit = qjit(function, __internal_fbody_src_provided__ = self.qjit_str)