Commit 0f514327 authored by Mccaskey, Alex's avatar Mccaskey, Alex
Browse files

add ability to construct qjit from qiskit or pyquil quantum circuits

parent 100a8164
Loading
Loading
Loading
Loading
+69 −5
Original line number Diff line number Diff line
@@ -194,12 +194,76 @@ class qjit(object):

    """

    def __get__qasm__generator__(self, python_data):
        """
        Query python_data to see if this is a qiskit.QuantumCircuit or 
        a pyquil.Program. Return the QASM string generator and the 
        correct XACC Compiler. 
        """
        qasm_gen = None
        compiler = None
        if hasattr(python_data, 'qasm'):
            # If this is not a function, see if it is a
            # Qiskit QuantumCircuit and process it (map to qjit kernel)
            qasm_gen = getattr(python_data, 'qasm')
            compiler = xacc.getCompiler('staq')
        elif hasattr(python_data, 'out'):
            qasm_gen = getattr(python_data, 'out')
            compiler = xacc.getCompiler('quilc')
        else:
            print('Invalid function-like instance passed to qjit.')
            exit(1)
        return qasm_gen, compiler
        
    def __convert__python__data__to__kernel__(self, python_data, *args, **kwargs):
        """
        Convert the incoming python data object (containing quantum circuit data) into 
        a python function adherent to the qcor QJIT kernel model. Also return the function 
        body source string
        """
        # Convert python data to a qasm_generator, run the generator
        # also get corresponding xacc Compiler
        qasm_str_gen, xacc_compiler = self.__get__qasm__generator__(python_data)
        qasm_str = qasm_str_gen()

        # generate unique function name, based on 
        # src hash so we get JIT cache benefit
        hash_object = hashlib.md5(qasm_str.encode('utf-8'))
        kernel_function_name = '__internal_qk_circuit_kernel_' + \
            str(hash_object.hexdigest())

        xacc_ir = xacc_compiler.compile(qasm_str)
        pyxasm_str = xacc.getCompiler('pyxasm').translate(
            xacc_ir.getComposites()[0], {'qreg_name': 'q'})
        pyxasm_str = ''.join(['    {}\n'.format(line)
                              for line in pyxasm_str.split('\n')])

        kernel_function_name = '__internal_qk_circuit_kernel_' + \
            str(hash_object.hexdigest())
        local_src = 'def {}(q : qreg):\n{}\n'.format(
            kernel_function_name, pyxasm_str)

        result = globals()
        exec(local_src, result)
        return result[kernel_function_name], pyxasm_str

    def __init__(self, function, *args, **kwargs):
        """Constructor for qjit, takes as input the annotated python function and any additional optional
        arguments that are used to customize the workflow."""
        self.args = args
        self.kwargs = kwargs

        if not callable(function):
            # Assume this is some pythonic data structure 
            # describing the quantum code (like qiskit QuantumCircuit, or pyquil Program)
            self.function, fbody_src = self.__convert__python__data__to__kernel__(
                    function, args, kwargs)
            # need to provide the function body since inspect wont be able to get it
            kwargs['__internal_fbody_src_provided__'] = fbody_src

        else:
            self.function = function
        
        self.allowed_type_cpp_map = {'<class \'_pyqcor.qreg\'>': 'qreg',
                                     '<class \'_pyqcor.qubit\'>': 'qubit',
                                     '<class \'float\'>': 'double', 'typing.List[float]': 'std::vector<double>',
@@ -666,24 +730,24 @@ class qjit(object):

    def mlir(self, *args, **kwargs):
        assert len(args) == len(self.arg_names), "Cannot generate MLIR, you did not provided the correct concrete kernel arguments."
        open_qasm_str = self.openqasm(*args)
        open_qasm_str = self.openqasm(*args).replace('OPENQASM 2.0', 'OPENQASM 3')
        return openqasm_to_mlir(open_qasm_str, self.kernel_name(), 
                        kwargs['add_entry_point'] if 'add_entry_point' in kwargs else True)
    
    def llvm_mlir(self, *args, **kwargs):
        assert len(args) == len(self.arg_names), "Cannot generate LLVM MLIR, you did not provided the correct concrete kernel arguments."
        open_qasm_str = self.openqasm(*args)
        open_qasm_str = self.openqasm(*args).replace('OPENQASM 2.0', 'OPENQASM 3')
        return openqasm_to_llvm_mlir(open_qasm_str, self.kernel_name(), 
                        kwargs['add_entry_point'] if 'add_entry_point' in kwargs else True)

    def llvm_ir(self, *args, **kwargs):
        assert len(args) == len(self.arg_names), "Cannot generate LLVM IR, you did not provided the correct concrete kernel arguments."
        open_qasm_str = self.openqasm(*args)
        open_qasm_str = self.openqasm(*args).replace('OPENQASM 2.0', 'OPENQASM 3')
        return openqasm_to_llvm_ir(open_qasm_str, self.kernel_name(), 
                        kwargs['add_entry_point'] if 'add_entry_point' in kwargs else True)

    def qir(self, *args, **kwargs):
        return llvm_ir(*args, **kwargs)
        return self.llvm_ir(*args, **kwargs)

    # Helper to construct the arg_dict (HetMap)
    # e.g. perform any additional type conversion if required.