Commit 747ff814 authored by Mccaskey, Alex's avatar Mccaskey, Alex
Browse files

removing requirement for depends_on in python decompose

parent d2808b5f
Loading
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -11,7 +11,7 @@ def ansatz(q : qreg, x : List[float]):
    # here todense() maps the sparse operator to a numpy.matrix.
    # Note that if your matrix is dependent on a kernel argument, 
    # you must define it in the depends_on=[..] decompose arg.
    with decompose(q, kak, depends_on=[x]) as u:
    with decompose(q, kak) as u:
        from scipy.sparse.linalg import expm
        from openfermion.ops import QubitOperator
        from openfermion.transforms import get_sparse_operator
+1 −1
Original line number Diff line number Diff line
@@ -56,7 +56,7 @@ print(random_1qbit.extract_composite(q).toString())
@qjit
def ansatz(q : qreg, x : List[float]):
    X(q[0])
    with decompose(q, kak, depends_on=[x]) as u:
    with decompose(q, kak) as u:
        from scipy.sparse.linalg import expm
        from openfermion.ops import QubitOperator
        from openfermion.transforms import get_sparse_operator
+101 −49
Original line number Diff line number Diff line
@@ -6,10 +6,11 @@ if '@QCOR_APPEND_PLUGIN_PATH@':
import xacc

from _pyqcor import *
import inspect
import inspect, ast
from typing import List
import typing
import re, itertools
import typing, types
import re
import itertools
from collections import defaultdict

List = typing.List
@@ -17,6 +18,7 @@ PauliOperator = xacc.quantum.PauliOperator
FLOAT_REF = typing.NewType('value', float)
INT_REF = typing.NewType('value', int)


def X(idx):
    return xacc.quantum.PauliOperator({idx: 'X'}, 1.0)

@@ -28,6 +30,7 @@ def Y(idx):
def Z(idx):
    return xacc.quantum.PauliOperator({idx: 'Z'}, 1.0)


cpp_matrix_gen_code = '''#include <pybind11/embed.h>
#include <pybind11/stl.h>
#include <pybind11/complex.h>
@@ -44,6 +47,8 @@ auto __internal__qcor_pyjit_gen_{}_unitary_matrix({}) {{
}}'''

# Simple graph class to help resolve kernel dependency (via topological sort)


class KernelGraph(object):
    def __init__(self):
        self.graph = defaultdict(list)
@@ -174,13 +179,17 @@ class qjit(object):
            lines = fbody_src.split('\n')

            # Get all lines that are 'with decompose...'
            with_decomp_lines = [line for line in lines if 'with decompose' in line if line.lstrip()[0] != '#']
            with_decomp_lines = [
                line for line in lines if 'with decompose' in line if line.lstrip()[0] != '#']
            # Get their index in the lines list
            with_decomp_lines_idxs = [lines.index(s) for s in with_decomp_lines]
            with_decomp_lines_idxs = [
                lines.index(s) for s in with_decomp_lines]
            # Get their column start integer
            with_decomp_lines_col_starts = [sum(1 for _ in itertools.takewhile(str.isspace,s)) for s in with_decomp_lines]
            with_decomp_lines_col_starts = [
                sum(1 for _ in itertools.takewhile(str.isspace, s)) for s in with_decomp_lines]
            # Get the name of the matrix we are decomposing
            with_decomp_matrix_names = [line.split(' ')[-1][:-1] for line in with_decomp_lines]
            with_decomp_matrix_names = [line.split(
                ' ')[-1][:-1] for line in with_decomp_lines]

            # Loop over all decompose segments
            for i, line_idx in enumerate(with_decomp_lines_idxs):
@@ -189,14 +198,16 @@ class qjit(object):
                # Get all lines in the with decompose scope
                # ends if we hit a line with column dedent
                for line in lines[line_idx+1:]:
                    col_loc = sum(1 for _ in itertools.takewhile(str.isspace,line))
                    col_loc = sum(
                        1 for _ in itertools.takewhile(str.isspace, line))
                    if col_loc == with_decomp_lines_col_starts[i]:
                        break
                    total_decompose_code += '\n' + line
                    stmts_to_run.append(line.lstrip())

                # Get decompose args
                decompose_args = re.search('\(([^)]+)', with_decomp_lines[i]).group(1)
                decompose_args = re.search(
                    '\(([^)]+)', with_decomp_lines[i]).group(1)
                d_list = decompose_args.split(',')
                decompose_args = d_list[0]
                for e in d_list[1:]:
@@ -204,30 +215,60 @@ class qjit(object):
                        break
                    decompose_args += ',' + e


                # Build up the matrix generation code
                code_to_exec = 'import numpy as np\n' + '\n'.join([s for s in stmts_to_run])
                code_to_exec += '\nmat_data = np.array(' + with_decomp_matrix_names[i]+').flatten()\n'
                code_to_exec += 'mat_size = '+ with_decomp_matrix_names[i]+'.shape[0]\n'
                code_to_exec = 'import numpy as np\n' + \
                    '\n'.join([s for s in stmts_to_run])
                code_to_exec += '\nmat_data = np.array(' + \
                    with_decomp_matrix_names[i]+').flatten()\n'
                code_to_exec += 'mat_size = ' + \
                    with_decomp_matrix_names[i]+'.shape[0]\n'
                # Users can use numpy. or np.
                code_to_exec = code_to_exec.replace('numpy.', 'np.')

                if 'depends_on' in with_decomp_lines[i]:
                # Figure out it code_to_exec depends on any 
                # kernel arguments
                class FindDependentKernelVariables(ast.NodeVisitor):
                    def __init__(self, arg_names):
                        self.depends_on = []
                        self.outer_parent_arg_names = arg_names
                    def visit_Name(self, node):
                        if node.id in self.outer_parent_arg_names:
                            self.depends_on.append(node.id)
                        self.generic_visit(node)
                tree = ast.parse(code_to_exec)
                analyzer = FindDependentKernelVariables(self.arg_names)
                analyzer.visit(tree)

                # analyzer.depends_on now has all kernel arg variables, 
                # used in the construction of the matrix

                if analyzer.depends_on:                 
                    # Need arg structure, python code, and locals[vars] code
                    depends_on_str = re.search(r"\[([A-Za-z0-9_]+)\]", with_decomp_lines[i]).group(1)
                    arg_struct = ','.join([self.allowed_type_cpp_map[str(self.type_annotations[s])]+' '+s for s in depends_on_str.split(',')])
                    arg_var_names = ','.join([s for s in depends_on_str.split(',')])
                    locals_code = '\n'.join(['locals["{}"] = {};'.format(n,n) for n in arg_var_names])
                    self.extra_cpp_code = cpp_matrix_gen_code.format(with_decomp_matrix_names[i], arg_struct, code_to_exec, locals_code)
                    arg_struct = ','.join([self.allowed_type_cpp_map[str(
                        self.type_annotations[s])]+' '+s for s in analyzer.depends_on])
                    arg_var_names = ','.join(
                        [s for s in analyzer.depends_on])
                    locals_code = '\n'.join(
                        ['locals["{}"] = {};'.format(n, n) for n in arg_var_names])
                    self.extra_cpp_code = cpp_matrix_gen_code.format(
                        with_decomp_matrix_names[i], arg_struct, code_to_exec, locals_code)

                    col_skip = ' '*with_decomp_lines_col_starts[i]
                    new_src = col_skip + 'decompose {\n'
                    new_src += col_skip + ' '*4 + 'auto [mat_data, mat_size] = __internal__qcor_pyjit_gen_{}_unitary_matrix({});\n'.format(with_decomp_matrix_names[i], arg_var_names)
                    new_src += col_skip+' '*4 + 'UnitaryMatrix {} = Eigen::Map<UnitaryMatrix>(mat_data.data(), mat_size, mat_size);\n'.format(with_decomp_matrix_names[i])
                    new_src += col_skip + '{}({});\n'.format('}', decompose_args)
                    fbody_src = fbody_src.replace(total_decompose_code, new_src)
                    new_src += col_skip + ' '*4 + \
                        'auto [mat_data, mat_size] = __internal__qcor_pyjit_gen_{}_unitary_matrix({});\n'.format(
                            with_decomp_matrix_names[i], arg_var_names)
                    new_src += col_skip+' '*4 + \
                        'UnitaryMatrix {} = Eigen::Map<UnitaryMatrix>(mat_data.data(), mat_size, mat_size);\n'.format(
                            with_decomp_matrix_names[i])
                    new_src += col_skip + \
                        '{}({});\n'.format('}', decompose_args)
                    fbody_src = fbody_src.replace(
                        total_decompose_code, new_src)
                else:
                    # Execute the code, extract the matrix data and size
                    # This is the case where the matrix is static and does 
                    # not depend on any kernel arguments
                    _locals = locals()
                    exec(code_to_exec, globals(), _locals)
                    data = _locals['mat_data']
@@ -237,11 +278,14 @@ class qjit(object):
                    # Replace total_decompose_code in fbody_src...
                    col_skip = ' '*with_decomp_lines_col_starts[i]
                    new_src = col_skip + 'decompose {\n'
                    new_src += col_skip+' '*4 + 'UnitaryMatrix {} = UnitaryMatrix::Zero({},{});\n'.format(with_decomp_matrix_names[i], mat_size,mat_size)
                    new_src += col_skip+' '*4 + '{} << {};\n'.format(with_decomp_matrix_names[i], data)
                    new_src += col_skip + '{}({});\n'.format('}', decompose_args)
                    fbody_src = fbody_src.replace(total_decompose_code, new_src)

                    new_src += col_skip+' '*4 + 'UnitaryMatrix {} = UnitaryMatrix::Zero({},{});\n'.format(
                        with_decomp_matrix_names[i], mat_size, mat_size)
                    new_src += col_skip+' '*4 + \
                        '{} << {};\n'.format(with_decomp_matrix_names[i], data)
                    new_src += col_skip + \
                        '{}({});\n'.format('}', decompose_args)
                    fbody_src = fbody_src.replace(
                        total_decompose_code, new_src)

        # Users must provide arg types, if not we throw an error
        if not self.type_annotations or len(self.arg_names) != len(self.type_annotations):
@@ -308,7 +352,8 @@ class qjit(object):
        # Persist *pass by ref* variables to the accelerator buffer:
        persist_by_ref_var_code = ''
        for ref_var in self.ref_type_args:
            persist_by_ref_var_code += '\npersist_var_to_qreq(\"' + ref_var + '\", ' + ref_var + ', '+ self.qRegName + ')' 
            persist_by_ref_var_code += '\npersist_var_to_qreq(\"' + \
                ref_var + '\", ' + ref_var + ', ' + self.qRegName + ')'

        # Create the qcor quantum kernel function src for QJIT and the Clang syntax handler
        self.src = '__qpu__ void '+self.function.__name__ + \
@@ -326,14 +371,14 @@ class qjit(object):
            if re.search(r"\b" + re.escape(kernelCall) + '|' + re.escape(kernelAdjCall) + '|' + re.escape(kernelCtrlCall), self.src):
                dependency.append(kernelName)


        self.__kernels__graph.addKernelDependency(
            self.function.__name__, dependency)
        self.sorted_kernel_dep = self.__kernels__graph.getSortedDependency(
            self.function.__name__)

        # Run the QJIT compile step to store function pointers internally
        self._qjit.internal_python_jit_compile(self.src, self.sorted_kernel_dep, self.extra_cpp_code)
        self._qjit.internal_python_jit_compile(
            self.src, self.sorted_kernel_dep, self.extra_cpp_code)
        self._qjit.write_cache()
        self.__compiled__kernels.append(self.function.__name__)
        return
@@ -380,9 +425,11 @@ class qjit(object):

        # Local vars used to figure out if we have
        # arg structures that look like (qreg, float...)
        type_annots_list = [str(self.type_annotations[x]) for x in self.arg_names]
        type_annots_list = [str(self.type_annotations[x])
                            for x in self.arg_names]
        default_float_args = ['<class \'float\'>']
        intersection = list(set(type_annots_list[1:]) & set(default_float_args) ) 
        intersection = list(
            set(type_annots_list[1:]) & set(default_float_args))

        if intersection == default_float_args:
            # This handles all (qreg, float...)
@@ -479,8 +526,10 @@ class qjit(object):
            # Retrieve *original* variable names of the argument pack
            frame = inspect.currentframe()
            frame = inspect.getouterframes(frame)[1]
            code_context_string = inspect.getframeinfo(frame[0]).code_context[0].strip()
            caller_args = code_context_string[code_context_string.find('(') + 1:-1].split(',')
            code_context_string = inspect.getframeinfo(
                frame[0]).code_context[0].strip()
            caller_args = code_context_string[code_context_string.find(
                '(') + 1:-1].split(',')
            caller_var_names = []
            for i in caller_args:
                i = i.strip()
@@ -492,13 +541,16 @@ class qjit(object):
            # Get the updated value:
            for by_ref_var in self.ref_type_args:
                updated_var = qReg.getInformation(by_ref_var)
                caller_var_name = caller_var_names[self.arg_names.index(by_ref_var)]
                caller_var_name = caller_var_names[self.arg_names.index(
                    by_ref_var)]
                if (caller_var_name in inspect.stack()[1][0].f_globals):
                    # Make sure it is the correct type:
                    by_ref_instane = inspect.stack()[1][0].f_globals[caller_var_name] 
                    by_ref_instane = inspect.stack(
                    )[1][0].f_globals[caller_var_name]
                    # We only support float and int atm
                    if (isinstance(by_ref_instane, float) or isinstance(by_ref_instane, int)):
                        inspect.stack()[1][0].f_globals[caller_var_name] = updated_var
                        inspect.stack()[
                            1][0].f_globals[caller_var_name] = updated_var

        return

+1 −1
Original line number Diff line number Diff line
@@ -111,7 +111,7 @@ class TestKernelJIT(unittest.TestCase):
    #     @qjit
    #     def ansatz(q : qreg, x : List[float]):
    #         X(q[0])
    #         with decompose(q, kak, depends_on=[x]) as u:
    #         with decompose(q, kak) as u:
    #             from scipy.sparse.linalg import expm
    #             from openfermion.ops import QubitOperator
    #             from openfermion.transforms import get_sparse_operator