qcor.in.py 51.7 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
# *******************************************************************************
# Copyright (c) 2018-, UT-Battelle, LLC.
# All rights reserved. This program and the accompanying materials
# are made available under the terms of the BSD 3-Clause License 
# which accompanies this distribution. 
#
# Contributors:
#   Alexander J. McCaskey - initial API and implementation
#   Thien Nguyen - implementation
# *******************************************************************************
11
import sys, uuid, atexit, hashlib
12
13
14
15

# '@QCOR_APPEND_PLUGIN_PATH@' is a configure-time placeholder (this file is a
# .in template). When it is non-empty, pass it on through two extra argv
# tokens, presumably read when xacc / _pyqcor initialise.
if len('@QCOR_APPEND_PLUGIN_PATH@') > 0:
    sys.argv.extend(('__internal__add__plugin__path', '@QCOR_APPEND_PLUGIN_PATH@'))

Mccaskey, Alex's avatar
Mccaskey, Alex committed
16
17
import xacc 

18
from _pyqcor import *
19
import inspect, ast
20
from typing import List
21
22
23
import typing, types
import re
import itertools
24
from collections import defaultdict
25

26
# Short aliases for the typing / types helpers used in kernel annotations.
List, Tuple, Callable = typing.List, typing.Tuple, typing.Callable
MethodType = types.MethodType
30

31
32
33
34
35
36
37
# KernelSignature type annotation:
# Usage: annotate a kernel function argument as a KernelSignature by writing
#     varName: KernelSignature(qreg, ...)
# The referenced kernel always returns void (None).
def KernelSignature(*args):
    """Return ``Callable[[*args], None]``, the annotation for a kernel-valued argument."""
    param_types = [*args]
    return Callable[param_types, None]

38
39
40
41
42
43
44
# Static cache of all Python QJIT objects that have been created.
# There seems to be a bug when a Python interpreter tries to create a new QJIT
# *after* a previous QJIT has been destroyed.
# Note: this only occurs when QJIT kernels are declared in local scopes,
# i.e. multiple kernels all declared in global scope don't have this issue.
# Hence, to be safe, we keep every QJIT object alive until interpreter exit.
QJIT_OBJ_CACHE = []


@atexit.register
def clear_qjit_cache():
    """Release every cached qjit object when the interpreter exits.

    The module-level list is cleared in place. Rebinding the name inside this
    function (``QJIT_OBJ_CACHE = []``) would only create a local variable and
    leave the cached objects untouched.
    """
    QJIT_OBJ_CACHE.clear()
48

49
50
# Marker types for "pass by reference" kernel arguments: a kernel parameter
# annotated with FLOAT_REF / INT_REF is emitted as `double&` / `int&` in the
# generated C++ signature. Both NewTypes are named 'value'; kernels compare
# against them by identity.
FLOAT_REF = typing.NewType('value', float)
INT_REF = typing.NewType('value', int)

# Map from str(<python type annotation>) to its short, user-facing name.
typing_to_simple_map = {'<class \'_pyqcor.qreg\'>': 'qreg',
                        '<class \'_pyqcor.qubit\'>': 'qubit',
                        '<class \'float\'>': 'float', 'typing.List[float]': 'List[float]',
                        '<class \'int\'>': 'int', 'typing.List[int]': 'List[int]',
                        '<class \'_pyqcor.Operator\'>': 'Operator',
                        'typing.List[typing.Tuple[int, int]]': 'List[Tuple[int,int]]',
                        'typing.List[_pyqcor.Operator]': 'List[Operator]'}
59

60
61
62
63
64
65
66
67
68
# Extra -I include flags for the clang code-gen step: every platform needs
# the Python include directory, and Mac OS X additionally needs the
# ';'-separated QCOR_EXTRA_HEADERS list (quotes stripped, empty entries skipped).
extra_headers = ['-I' + '@Python_INCLUDE_DIRS@']
tmp_extra_headers = '@QCOR_EXTRA_HEADERS@'.replace('"', '')
for path in tmp_extra_headers.split(';'):
    if not path:
        continue
    extra_headers.append('-I{}'.format(path))
69

70
def X(idx):
    """Return the 'pauli' Operator ``X<idx>`` (Pauli X acting on qubit ``idx``)."""
    return Operator('pauli', 'X{}'.format(idx))
72

73
74

def Y(idx):
    """Return the 'pauli' Operator ``Y<idx>`` (Pauli Y acting on qubit ``idx``)."""
    return Operator('pauli', 'Y{}'.format(idx))
76

77
78

def Z(idx):
    """Return the 'pauli' Operator ``Z<idx>`` (Pauli Z acting on qubit ``idx``)."""
    return Operator('pauli', 'Z{}'.format(idx))
80

81
def adag(idx):
    """Return the 'fermion' Operator ``'1.0 <idx>^'``.

    Going by the name, this is presumably the creation operator a^dagger on
    mode ``idx``.
    """
    return Operator('fermion', '1.0 {}^'.format(idx))
83
84

def a(idx):
    """Return the 'fermion' Operator ``'1.0 <idx>'``.

    Going by the name, this is presumably the annihilation operator a on
    mode ``idx``.
    """
    return Operator('fermion', '1.0 {}'.format(idx))
86

87
88
89
90
91
# C++ template for a helper that runs user-provided Python (numpy) code through
# an embedded pybind11 interpreter. The helper returns the resulting matrix as
# flattened 1-D complex data plus its (square) dimension.
# str.format fields, in order: kernel name, a tag identifying the matrix,
# C++ parameter list, Python source to execute, and C++ statements that copy
# the parameters into the Python `locals` dict. Literal C++ braces are doubled
# ({{ / }}) because the template is filled with str.format.
cpp_matrix_gen_code = '''#include <pybind11/embed.h>
#include <pybind11/stl.h>
#include <pybind11/complex.h>
namespace py = pybind11;
// returns 1d data as vector and matrix size (assume square)
auto __internal__qcor_pyjit_{}_gen_{}_unitary_matrix({}) {{
  auto py_src = R"#({})#";
  auto locals = py::dict();
  {}
  py::exec(py_src, py::globals(), locals);
  return std::make_pair(
      locals["mat_data"].cast<std::vector<std::complex<double>>>(),
      locals["mat_size"].cast<int>());
}}'''
101

102
# Simple graph class to help resolve kernel dependency (via topological sort)
103
104


105
106
107
108
class KernelGraph(object):
    """Dependency graph between JIT-compiled quantum kernels.

    Every kernel registered through createKernelDependency becomes a vertex.
    Its integer id is its position in ``kernel_name_list``, and
    ``kernel_idx_dep_map`` maps that id to the ids of the kernels it calls.
    A topological sort (Kahn's algorithm) orders kernels so that each one
    appears after every kernel it depends on.
    """

    def __init__(self):
        # Adjacency list (dependency id -> ids of kernels that call it);
        # rebuilt from kernel_idx_dep_map by every topologicalSort() call.
        self.graph = defaultdict(list)
        # Number of registered kernels (vertices).
        self.V = 0
        # Vertex id -> list of vertex ids this kernel depends on.
        self.kernel_idx_dep_map = {}
        # Vertex id -> kernel name.
        self.kernel_name_list = []

    def createKernelDependency(self, kernelName, depList):
        """Register ``kernelName`` as a new vertex that depends on ``depList``.

        Dependency names are resolved with ``list.index`` (the first
        registration of a name wins). Each must already be registered or be
        ``kernelName`` itself; otherwise ValueError is raised and the graph is
        left unchanged.
        """
        # Resolve every dependency before mutating any state. This way an
        # unknown name cannot leave kernel_name_list and the vertex ids out of
        # step with each other.
        candidate_names = self.kernel_name_list + [kernelName]
        dep_idxs = [candidate_names.index(dep_ker_name) for dep_ker_name in depList]
        self.kernel_name_list.append(kernelName)
        self.kernel_idx_dep_map[self.V] = dep_idxs
        self.V += 1

    def addKernelDependency(self, kernelName, newDep):
        """Record that the already-registered ``kernelName`` also depends on ``newDep``."""
        self.kernel_idx_dep_map[self.kernel_name_list.index(kernelName)].append(
            self.kernel_name_list.index(newDep))

    def addEdge(self, u, v):
        """Add the directed edge u -> v (u must be ordered before v)."""
        self.graph[u].append(v)

    def _topological_order(self):
        """Return vertex ids in dependency order using Kahn's algorithm.

        Vertices that lie on a dependency cycle never reach in-degree zero, so
        they are absent from the result.
        """
        from collections import deque

        self.graph = defaultdict(list)
        for sub_ker_idx, dep_idxs in self.kernel_idx_dep_map.items():
            for dep_sub_idx in dep_idxs:
                self.addEdge(dep_sub_idx, sub_ker_idx)

        in_degree = [0] * self.V
        for successors in self.graph.values():
            for j in successors:
                in_degree[j] += 1

        queue = deque(i for i in range(self.V) if in_degree[i] == 0)
        order = []
        while queue:
            u = queue.popleft()
            order.append(u)
            for v in self.graph[u]:
                in_degree[v] -= 1
                if in_degree[v] == 0:
                    queue.append(v)
        return order

    # Topological Sort.
    def topologicalSort(self):
        """Return the names of all registered kernels in dependency order."""
        return [self.kernel_name_list[i] for i in self._topological_order()]

    def getSortedDependency(self, kernelName):
        """Return the transitive dependencies of ``kernelName`` in compilation order.

        Only kernels reachable from ``kernelName`` through its dependency edges
        are returned; ``kernelName`` itself is excluded. Unrelated kernels that
        merely happen to sort earlier are not included.

        Raises ValueError if ``kernelName`` is unknown, or if it or one of its
        dependencies lies on a dependency cycle.
        """
        kernel_idx = self.kernel_name_list.index(kernelName)
        # No dependency
        if not self.kernel_idx_dep_map[kernel_idx]:
            return []

        # Transitive closure of the kernel's dependencies.
        required = set()
        pending = list(self.kernel_idx_dep_map[kernel_idx])
        while pending:
            idx = pending.pop()
            if idx not in required:
                required.add(idx)
                pending.extend(self.kernel_idx_dep_map[idx])

        order = [i for i in self._topological_order() if i in required]
        if kernel_idx in required or len(order) != len(required):
            raise ValueError(
                'Cyclic kernel dependency detected for kernel: ' + kernelName)
        return [self.kernel_name_list[i] for i in order]
172

173

174
class qjit(object):
Mccaskey, Alex's avatar
Mccaskey, Alex committed
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
    """
    The qjit class serves as a python function decorator that enables
    the just-in-time compilation of quantum python functions (kernels) using
    the QCOR QJIT infrastructure. Example usage:

    @qjit
    def kernel(q : qreg, theta : float):
        X(q[0])
        Ry(q[1], theta)
        CX(q[1], q[0])
        for i in range(q.size()):
            Measure(q[i])

    q = qalloc(2)
    kernel(q, 0.5)
    print(q.counts())

    Upon initialization, the python inspect module is used to extract the function body
    as a string. This string is processed to create a corresponding C++ function with
    the pythonic function body as an embedded domain specific language. The QCOR QJIT engine
    takes this function string and delegates to the QCOR Clang SyntaxHandler infrastructure, which
    maps this function to a QCOR QuantumKernel sub-type, compiles it to LLVM bitcode, caches that
    bitcode for future fast lookup, and extracts function pointers using the LLVM JIT engine that
    are called later to execute the quantum code.

    Note that kernel function arguments must provide type hints. Allowed types are qreg, qubit,
    int, float, List[int], List[float], List[Tuple[int, int]], Operator, List[Operator],
    KernelSignature(...) callables (or lists of them), and the by-reference FLOAT_REF / INT_REF types.

    qjit annotated functions can also be passed as general functors to other QCOR API calls like
    createObjectiveFunction, and createModel from the QSim library.

    """

207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
    def __get__qasm__generator__(self, python_data):
        """
        Query python_data to see if this is a qiskit.QuantumCircuit or
        a pyquil.Program. Return the QASM string generator and the
        correct XACC Compiler.

        Objects with a ``qasm`` attribute (Qiskit) are paired with the 'staq'
        compiler. Objects with an ``out`` attribute (pyQuil) are paired with
        'quilc'. Anything else prints an error and terminates the process.
        """
        if hasattr(python_data, 'qasm'):
            # If this is not a function, see if it is a
            # Qiskit QuantumCircuit and process it (map to qjit kernel)
            return getattr(python_data, 'qasm'), xacc.getCompiler('staq')
        if hasattr(python_data, 'out'):
            return getattr(python_data, 'out'), xacc.getCompiler('quilc')
        print('Invalid function-like instance passed to qjit.')
        # sys.exit rather than the site-module `exit` helper: the latter is meant
        # for the interactive shell and is missing under `python -S`.
        sys.exit(1)
        
    def __convert__python__data__to__kernel__(self, python_data, *args, **kwargs):
        """
        Convert the incoming python data object (containing quantum circuit data) into
        a python function adherent to the qcor QJIT kernel model. Also return the function
        body source string.

        Steps:
        - Serialize the circuit to QASM / Quil text.
        - Compile the text to XACC IR.
        - Translate the IR to pyxasm statements.
        - Wrap the statements in a generated ``def <name>(q : qreg):``.

        Extra ``args`` / ``kwargs`` are accepted but unused.

        Returns (kernel_function, body_src). In body_src every line is indented
        by four spaces.
        """
        # Convert python data to a qasm_generator, run the generator
        # also get corresponding xacc Compiler
        qasm_str_gen, xacc_compiler = self.__get__qasm__generator__(python_data)
        qasm_str = qasm_str_gen()
        # Rename the `cp` gate to `cu1`. Only whole tokens are rewritten: a plain
        # substring replace would also corrupt any longer identifier that merely
        # contains "cp" (e.g. a register named `qcp`).
        qasm_str = re.sub(r'\bcp\b', 'cu1', qasm_str)

        # generate unique function name, based on
        # src hash so we get JIT cache benefit
        hash_object = hashlib.md5(qasm_str.encode('utf-8'))
        kernel_function_name = '__internal_qk_circuit_kernel_' + hash_object.hexdigest()

        xacc_ir = xacc_compiler.compile(qasm_str)
        pyxasm_str = xacc.getCompiler('pyxasm').translate(
            xacc_ir.getComposites()[0], {'qreg_name': 'q'})
        pyxasm_str = ''.join('    {}\n'.format(line) for line in pyxasm_str.split('\n'))

        local_src = 'def {}(q : qreg):\n{}\n'.format(kernel_function_name, pyxasm_str)

        # Define the generated function in this module's namespace, where the
        # `qreg` annotation resolves.
        result = globals()
        exec(local_src, result)
        return result[kernel_function_name], pyxasm_str

261
    def __init__(self, function, *args, **kwargs):
        """Constructor for qjit, takes as input the annotated python function and any additional optional
        arguments that are used to customize the workflow.

        ``function`` is either a type-annotated Python kernel function or a
        non-callable circuit object (e.g. a Qiskit QuantumCircuit or a pyQuil
        Program), which is first converted into an equivalent kernel function.
        Keyword arguments are also copied onto the instance as attributes.

        Construction does the following:
        - builds the C++ / embedded-Python DSL source (``self.src``);
        - resolves dependencies on previously compiled qjit kernels;
        - JIT-compiles the kernel through QJIT;
        - records it in the class-level registries.

        Invalid kernels print an error and exit the process.
        """
        self.args = args
        self.kwargs = kwargs

        if not callable(function):
            # Assume this is some pythonic data structure describing the quantum
            # code (like qiskit QuantumCircuit, or pyquil Program).
            self.function, fbody_src = self.__convert__python__data__to__kernel__(
                function, *args, **kwargs)
            # need to provide the function body since inspect wont be able to get it
            kwargs['__internal_fbody_src_provided__'] = fbody_src
        else:
            self.function = function

        # str(<annotation>) -> C++ parameter type.
        self.allowed_type_cpp_map = {'<class \'_pyqcor.qreg\'>': 'qreg',
                                     '<class \'_pyqcor.qubit\'>': 'qubit',
                                     '<class \'float\'>': 'double', 'typing.List[float]': 'std::vector<double>',
                                     '<class \'int\'>': 'int', 'typing.List[int]': 'std::vector<int>',
                                     '<class \'_pyqcor.Operator\'>': 'qcor::Operator',
                                     'typing.List[typing.Tuple[int, int]]': 'PairList<int>',
                                     'typing.List[_pyqcor.Operator]': 'std::vector<qcor::Operator>'}
        self.__dict__.update(kwargs)

        # Create the qcor just in time engine
        self._qjit = QJIT()
        # Auxiliary C++ (matrix-generation helpers) compiled with the kernel.
        self.extra_cpp_code = ''

        # Get the kernel function body as a string
        if '__internal_fbody_src_provided__' in kwargs:
            fbody_src = kwargs['__internal_fbody_src_provided__']
        else:
            fbody_src = self._get_function_body_src(self.function)

        # Get the arg variable names and their types. A return annotation
        # (e.g. `-> None`) is not a kernel argument, so it is excluded.
        self.arg_names, _, _, _, _, _, annotations = inspect.getfullargspec(
            self.function)
        self.type_annotations = {name: hint for name, hint in annotations.items()
                                 if name != 'return'}

        # Users must provide arg types, if not we throw an error. This is checked
        # before the source rewrites below, which look argument types up.
        if not self.type_annotations or len(self.arg_names) != len(self.type_annotations):
            print('Error, you must provide type annotations for qcor quantum kernels.')
            sys.exit(1)

        # Rewrite python-only constructs into their C++ DSL forms.
        if 'with decompose' in fbody_src:
            fbody_src = self._rewrite_decompose_blocks(fbody_src)
        if 'with compute' in fbody_src:
            fbody_src = self._rewrite_compute_action_blocks(fbody_src)

        # Construct the C++ kernel arg string (also sets ref_type_args / qRegName).
        cpp_arg_str = self._build_cpp_arg_str()

        # Globals of the frame that created this qjit (normally the module that
        # defines the kernel). This must run in __init__ itself so that
        # stack()[1] is that caller; context=0 avoids reading source lines.
        caller_globals = inspect.stack(0)[1][0].f_globals
        globalDeclStr, fbody_src = self._import_caller_globals(caller_globals, fbody_src)

        # Persist *pass by ref* variables to the accelerator buffer:
        persist_by_ref_var_code = ''.join(
            '\npersist_var_to_qreg("' + ref_var + '", ' + ref_var + ', ' + self.qRegName + ')'
            for ref_var in self.ref_type_args)

        # Create the qcor quantum kernel function src for QJIT and the Clang syntax handler
        self.src = '__qpu__ void ' + self.function.__name__ + \
            '(' + cpp_arg_str + ') {\nusing qcor::pyxasm;\n' + \
            globalDeclStr + '\n' + fbody_src + persist_by_ref_var_code + "}\n"

        # Handle nested kernels: depend on every previously compiled kernel that
        # is called as `k(`, `k.adjoint(` or `k.ctrl(`. The alternation is
        # grouped so the word boundary applies to all three forms (ungrouped,
        # `\b` bound only to the first, so e.g. `my_k.adjoint(` matched `k`).
        dependency = []
        for kernelName in self.__compiled__kernels:
            call_forms = (kernelName + '(', kernelName + '.adjoint(', kernelName + '.ctrl(')
            pattern = r'\b(?:' + '|'.join(re.escape(form) for form in call_forms) + ')'
            if re.search(pattern, self.src):
                dependency.append(kernelName)

        self.__kernels__graph.createKernelDependency(
            self.function.__name__, dependency)
        self.sorted_kernel_dep = self.__kernels__graph.getSortedDependency(
            self.function.__name__)

        # Run the QJIT compile step to store function pointers internally
        self._qjit.internal_python_jit_compile(
            self.src, self.sorted_kernel_dep, self.extra_cpp_code, extra_headers)
        self._qjit.write_cache()
        self.__compiled__kernels.append(self.function.__name__)
        QJIT_OBJ_CACHE.append(self)

    @staticmethod
    def _leading_whitespace(line):
        """Return the number of leading whitespace characters in ``line``."""
        return sum(1 for _ in itertools.takewhile(str.isspace, line))

    @staticmethod
    def _get_function_body_src(function):
        """Return the source of ``function``'s body (every line after its signature).

        The signature end is located with ``ast``, so multi-line signatures,
        several decorators, or no decorator at all (``qjit(f)``) are all
        handled. The body keeps its original indentation.
        """
        import textwrap

        src_lines = inspect.getsource(function).split('\n')
        try:
            # dedent so kernels defined in nested scopes still parse; dedent
            # keeps the line count, so line numbers map 1:1 onto src_lines.
            func_def = ast.parse(textwrap.dedent('\n'.join(src_lines))).body[0]
            body_start = func_def.body[0].lineno - 1
        except (SyntaxError, IndexError, AttributeError):
            # Fall back to the historical assumption: one decorator line
            # followed by a one-line signature.
            body_start = 2
        return '\n'.join(src_lines[body_start:])

    def _rewrite_decompose_blocks(self, fbody_src):
        """Rewrite every ``with decompose(...) as U:`` block into a C++ ``decompose`` block.

        The python statements in the block build the matrix ``U``, and the
        result is ``decompose { ... }(<decompose args>);``.

        - If the statements reference no kernel argument, they are executed now
          and the matrix is embedded as literal data.
        - Otherwise a helper built from ``cpp_matrix_gen_code`` is appended to
          ``self.extra_cpp_code``. It re-runs the statements through embedded
          python when the kernel is called.
        """
        import textwrap

        lines = fbody_src.split('\n')
        # (line index, header) for every non-commented `with decompose` line.
        headers = [(idx, line) for idx, line in enumerate(lines)
                   if 'with decompose' in line and not line.lstrip().startswith('#')]

        for seg_no, (line_idx, header) in enumerate(headers):
            col_start = self._leading_whitespace(header)
            col_skip = ' ' * col_start

            # The block ends at the first code line indented at or left of the
            # `with` line; blank and comment-only lines never end it.
            block_lines = [header]
            body_lines = []
            for line in lines[line_idx + 1:]:
                stripped = line.strip()
                is_code = bool(stripped) and not stripped.startswith('#')
                if is_code and self._leading_whitespace(line) <= col_start:
                    break
                block_lines.append(line)
                if is_code:
                    body_lines.append(line)
            total_decompose_code = '\n'.join(block_lines)

            # Get the name of the matrix we are decomposing (`... as U:`)
            name_match = re.search(r'\bas\s+(\w+)\s*:', header)
            matrix_name = name_match.group(1) if name_match else header.split(' ')[-1][:-1]

            # Build up the matrix generation code. dedent (not lstrip) keeps
            # the relative indentation of nested statements such as loops.
            code_to_exec = 'import numpy as np\n' + textwrap.dedent('\n'.join(body_lines))
            code_to_exec += '\nmat_data = np.array(' + matrix_name + ').flatten()\n'
            code_to_exec += 'mat_size = ' + matrix_name + '.shape[0]\n'
            # Users can use numpy. or np.
            code_to_exec = code_to_exec.replace('numpy.', 'np.')

            # decompose(...) arguments, dropping the first depends_on=... entry
            # and everything after it.
            arg_parts = re.search(r'\(([^)]+)', header).group(1).split(',')
            kept_args = [arg_parts[0]] + list(itertools.takewhile(
                lambda part: 'depends_on' not in part, arg_parts[1:]))
            decompose_args = ','.join(kept_args)

            # Kernel arguments used to build the matrix (deduplicated so each
            # becomes exactly one parameter of the C++ helper).
            depends_on = []
            for node in ast.walk(ast.parse(code_to_exec)):
                if (isinstance(node, ast.Name) and node.id in self.arg_names
                        and node.id not in depends_on):
                    depends_on.append(node.id)

            new_src = col_skip + 'decompose {\n'
            if depends_on:
                # The segment number makes every helper name unique, even when
                # two blocks decompose matrices with the same name.
                helper_tag = '{}_{}'.format(matrix_name, seg_no)
                arg_struct = ','.join(
                    self.allowed_type_cpp_map[str(self.type_annotations[n])] + ' ' + n
                    for n in depends_on)
                arg_var_names = ','.join(depends_on)
                locals_code = '\n'.join(
                    'locals["{}"] = {};'.format(n, n) for n in depends_on)
                # Accumulate: each dependent block needs its own helper.
                if self.extra_cpp_code:
                    self.extra_cpp_code += '\n'
                self.extra_cpp_code += cpp_matrix_gen_code.format(
                    self.kernel_name(), helper_tag, arg_struct, code_to_exec, locals_code)
                new_src += col_skip + '    auto [mat_data, mat_size] = __internal__qcor_pyjit_{}_gen_{}_unitary_matrix({});\n'.format(
                    self.kernel_name(), helper_tag, arg_var_names)
                new_src += col_skip + '    UnitaryMatrix {} = Eigen::Map<UnitaryMatrix>(mat_data.data(), mat_size, mat_size);\n'.format(
                    matrix_name)
            else:
                # Static matrix: run the code now in a private copy of this
                # module's globals (a single namespace, so comprehensions and
                # helper functions in the user code can see earlier names).
                exec_scope = dict(globals())
                exec(code_to_exec, exec_scope)
                data = ','.join(str(d) for d in exec_scope['mat_data'])
                mat_size = exec_scope['mat_size']
                new_src += col_skip + '    UnitaryMatrix {} = UnitaryMatrix::Zero({},{});\n'.format(
                    matrix_name, mat_size, mat_size)
                new_src += col_skip + '    {} << {};\n'.format(matrix_name, data)
            new_src += col_skip + '}(' + decompose_args + ');\n'
            # Replace only the first remaining occurrence, so identical blocks
            # are each rewritten by their own iteration.
            fbody_src = fbody_src.replace(total_decompose_code, new_src, 1)
        return fbody_src

    @staticmethod
    def _rewrite_compute_action_blocks(fbody_src):
        """Rewrite ``with compute:`` / ``with action:`` pairs into ``compute { ... } action { ... }``.

        The action block is closed with ``}`` at the first code line indented at
        or left of its ``with action`` line, or at the end of the body. Blank and
        comment-only lines never open or close a block.
        """
        lines = fbody_src.split('\n')

        def block_kind(line):
            stripped = line.lstrip()
            if stripped.startswith('with compute'):
                return 'compute'
            if stripped.startswith('with action'):
                return 'action'
            return None

        kinds = [block_kind(line) for line in lines]
        if kinds.count('compute') != kinds.count('action'):
            print('Error, every `with compute` block must be paired with a `with action` block.')
            sys.exit(1)

        new_lines = []
        in_compute_block = in_action_block = False
        compute_col = action_col = 0
        for line, kind in zip(lines, kinds):
            stripped = line.strip()
            is_code = bool(stripped) and not stripped.startswith('#')
            current_col = len(line) - len(line.lstrip())
            if is_code and in_action_block and current_col <= action_col:
                # Dedented past `with action`: close the action block once.
                new_lines.append('}')
                in_action_block = False
            if (is_code and in_compute_block and current_col <= compute_col
                    and kind != 'action'):
                # here we have just dropped out of compute col
                print('After compute block, you must provide the action block.')
                sys.exit(1)

            if kind == 'compute':
                in_compute_block, in_action_block = True, False
                compute_col = current_col
                new_lines.append('compute {')
            elif kind == 'action':
                in_compute_block, in_action_block = False, True
                action_col = current_col
                new_lines.append('} action {')
            else:
                new_lines.append(line)
        if in_action_block:
            # The action block ran to the end of the kernel body.
            new_lines.append('}')
        return '\n'.join(new_lines) + '\n'

    def _print_allowed_arg_types(self):
        """Print the user-facing names of all supported kernel argument types."""
        print('\nAllowed kernel arg types:\n{}'.format('\n'.join([k.replace('<class ', '').replace(
            '>', '').replace('typing.', '').replace("'", '') for k in self.allowed_type_cpp_map.keys()])))

    def _callable_signature_to_cpp(self, clb_type):
        """Map a ``Callable[[T1, ...], None]`` annotation to ``KernelSignature<C1,...>``."""
        cpp_types = []
        for arg_type in clb_type.__args__[:-1]:
            if str(arg_type) not in self.allowed_type_cpp_map:
                print('Error, this quantum kernel arg type is not allowed: ', str(arg_type))
                self._print_allowed_arg_types()
                sys.exit(1)
            cpp_types.append(self.allowed_type_cpp_map[str(arg_type)])
        # join (rather than trimming a trailing comma) keeps the '<' for
        # zero-argument signatures.
        return 'KernelSignature<' + ','.join(cpp_types) + '>'

    def _build_cpp_arg_str(self):
        """Translate the annotated python signature into a C++ parameter list.

        Also sets ``self.ref_type_args`` (names of FLOAT_REF / INT_REF args) and
        ``self.qRegName`` (name of the last qreg argument, or '').
        """
        cpp_args = []
        self.ref_type_args = []
        self.qRegName = ''
        for arg, _type in self.type_annotations.items():
            if _type is FLOAT_REF:
                self.ref_type_args.append(arg)
                cpp_args.append('double& ' + arg)
                continue
            if _type is INT_REF:
                self.ref_type_args.append(arg)
                cpp_args.append('int& ' + arg)
                continue

            type_str = str(_type)
            # Single Callable argument
            if type_str.startswith('typing.Callable'):
                cpp_args.append(self._callable_signature_to_cpp(_type) + ' ' + arg)
                continue
            # List of KernelSignature (all Callables in the list share one signature)
            if type_str.startswith('typing.List[typing.Callable'):
                cpp_args.append('std::vector<' + self._callable_signature_to_cpp(
                    _type.__args__[0]) + '> ' + arg)
                continue

            if type_str not in self.allowed_type_cpp_map:
                print('Error, {} quantum kernel arg type is not allowed: '.format(arg), type_str)
                self._print_allowed_arg_types()
                sys.exit(1)

            cpp_type = self.allowed_type_cpp_map[type_str]
            if cpp_type == 'qreg':
                self.qRegName = arg
            cpp_args.append(cpp_type + ' ' + arg)
        return ','.join(cpp_args)

    def _import_caller_globals(self, caller_globals, fbody_src):
        """Inline the caller's float globals and normalise module aliases.

        Returns ``(global_decl_str, fbody_src)``:
        - ``global_decl_str`` has one ``name = value`` line for each float
          global whose name appears as a whole word in the body and is not a
          kernel argument (only floats are supported).
        - In ``fbody_src`` each module alias is replaced by the real module
          name, e.g. ``np.pi`` -> ``numpy.pi`` after ``import numpy as np``.
        """
        global_var_decl = []
        imported_modules = {}
        for key, value in list(caller_globals.items()):
            if isinstance(value, types.ModuleType):
                # Cache module import and its potential alias
                imported_modules[key] = value.__name__
            elif (isinstance(value, float) and key not in self.arg_names
                  and re.search(r'\b' + re.escape(key) + r'\b', fbody_src)):
                global_var_decl.append(key + ' = ' + str(value))

        for alias, module_name in imported_modules.items():
            if alias != module_name:
                # Whole-token match only: `m.` must not hit `sum.` or `x.m.`.
                fbody_src = re.sub(r'(?<![\w.])' + re.escape(alias) + r'\.',
                                   lambda _match, name=module_name: name + '.',
                                   fbody_src)
        return '\n'.join(global_var_decl), fbody_src

593
594
    # Static list of the names of all kernels compiled so far, in compilation
    # order (shared by every qjit instance).
    __compiled__kernels = []
    # Static dependency graph between compiled kernels (shared by every qjit instance).
    __kernels__graph = KernelGraph()

597
598
599
600
601
602
603
    def get_syntax_handler_src(self):
        """
        Debugging aid: return the C++ code that the QCOR SyntaxHandler
        generates from this kernel's internal source (``self.src``).
        """
        handler_output = self._qjit.run_syntax_handler(self.src)
        return handler_output[1]

604
605
606
607
608
609
610
611
612
    def get_extra_cpp_code(self):
        """Return the auxiliary C++ code passed to the JIT compiler alongside this kernel."""
        extra_code = self.extra_cpp_code
        return extra_code

    def get_sorted_kernels_deps(self):
        """Return the kernel dependency list computed at construction time (KernelGraph.getSortedDependency)."""
        sorted_deps = self.sorted_kernel_dep
        return sorted_deps

613
    def get_internal_src(self):
        """
        Return the C++ / embedded python DSL function code that is passed to QJIT
        and the clang syntax handler. Primarily intended for developers.
        """
        return self.src

    def kernel_name(self):
        """Return the quantum kernel function name."""
        return self.function.__name__

    def translate(self, q: qreg, x: List[float]):
        """
        This method is primarily used internally to map Optimizer parameters x : List[float] to
        the argument structure expected by the quantum kernel. It returns a dictionary where
        argument variable names serve as keys, and values are corresponding argument instances.

        Supported kernel signatures:
          * (qreg, float, ..., float): the i-th float argument receives x[i]. x must
            supply at least one value per float argument; extra values are ignored.
          * (qreg, List[float]): the list argument receives x itself.
          * (qreg): x must be empty.
        Any other signature, or an x that violates the rules above, prints an error
        and exits the process.
        """
        qreg_type = '<class \'_pyqcor.qreg\'>'
        float_type = '<class \'float\'>'
        list_float_type = 'typing.List[float]'
        arg_types = [str(self.type_annotations[name]) for name in self.arg_names]

        # (qreg, float...): EVERY argument after the first must be a float. (A
        # set-intersection test accepted e.g. (qreg, float, int) and then fed a
        # float to the int argument.)
        float_args = self.arg_names[1:]
        if float_args and all(t == float_type for t in arg_types[1:]):
            if len(x) < len(float_args):
                print(
                    'Error, could not translate vector parameters x into arguments for quantum kernel. ', len(x), len(float_args))
                sys.exit(1)
            ret_dict = {self.arg_names[0]: q}
            for arg_name, value in zip(float_args, x):
                ret_dict[arg_name] = value
            return ret_dict

        if arg_types == [qreg_type, list_float_type]:
            return {self.arg_names[0]: q, self.arg_names[1]: x}

        if arg_types == [qreg_type]:
            if len(x):
                print('invalid translate args, there is no x float list for this kernel.')
                sys.exit(1)
            return {self.arg_names[0]: q}

        print('currently cannot translate other arg structures: ', x)
        sys.exit(1)

    def extract_composite(self, *args):
        """
        Build the XACC CompositeInstruction for this kernel evaluated at the
        given concrete arguments (exactly one per kernel argument).
        """
        assert len(args) == len(self.arg_names), "Cannot create CompositeInstruction, you did not provided the correct kernel arguments."
        # construct_arg_dict may swap self._qjit (recompilation), so read
        # self._qjit only after it returns.
        kernel_args = self.construct_arg_dict(*args)
        name = self.function.__name__
        return self._qjit.extract_composite(name, kernel_args)

    def observe(self, observable, *args):
        """
        Return the expectation value of <observable> with
        respect to the state given by this qjit kernel evaluated
        at the given arguments.

        If the kernel's first argument is annotated as a qreg whose size equals
        observable.nBits(), that qreg is forwarded to internal_observe so users
        get a handle to the qreg holding child buffer information (e.g.
        bitstrings / exp-val-z of each term). Otherwise internal_observe is
        called without a register.
        """
        program = self.extract_composite(*args)
        # Guard zero-argument kernels: arg_names[0] / args[0] would raise IndexError.
        first_arg_is_qreg = (
            len(self.arg_names) > 0 and
            str(self.type_annotations[self.arg_names[0]]) == '<class \'_pyqcor.qreg\'>')
        if first_arg_is_qreg and args[0].size() == observable.nBits():
            return internal_observe(program, observable, args[0])
        # Otherwise, qcor will use a temp. buffer and just return the expectation value.
        return internal_observe(program, observable)

    def autograd(self, observable, qreg, x_vec):
        """
        Return the expectation value and gradients of <observable> with
        respect to the state given by this qjit kernel evaluated
        at the given arguments.

        qreg and each parameter vector are mapped onto the kernel arguments via
        translate(); a bare float x_vec is treated as a one-element vector.
        """
        def kernel_eval(x):
            # Rebuild the CompositeInstruction for parameter vector x.
            kernel_args = self.translate(qreg, x)
            return self._qjit.extract_composite(self.function.__name__, kernel_args)

        params = [x_vec] if isinstance(x_vec, float) else x_vec
        return internal_autograd(kernel_eval, observable, params)

    def openqasm(self, *args):
        """
        Return an OpenQasm string for this quantum kernel evaluated at the
        given arguments, produced by XACC's 'staq' compiler.
        """
        composite = self.extract_composite(*args)
        staq_compiler = xacc.getCompiler('staq')
        xacc_ir = composite.as_xacc()
        return staq_compiler.translate(xacc_ir)

    def print_kernel(self, *args):
        """
        Print this QJIT kernel, evaluated at the given arguments, as a
        QASM-like string (the CompositeInstruction's toString()).
        """
        composite = self.extract_composite(*args)
        print(composite.toString())

    def print_native_code(self, *args, **kwargs):
        """
        Print the native code targeting the Accelerator backend for this
        kernel at the given arguments. Keyword arguments are handed to the
        QJIT as a dict.
        """
        kernel_args = self.construct_arg_dict(*args)
        native_code = self._qjit.get_native_code(self.function.__name__, kernel_args, kwargs)
        print(native_code)

    def n_instructions(self, *args):
        """
        Count the quantum instructions of this kernel evaluated at the given
        arguments.
        """
        composite = self.extract_composite(*args)
        return composite.nInstructions()

    def as_unitary_matrix(self, *args):
        """
        Return the QJIT's internal_as_unitary result for this kernel at the
        given arguments (presumably the kernel's unitary matrix, per the
        method name; confirm against the _pyqcor binding).
        """
        kernel_args = self.construct_arg_dict(*args)
        return self._qjit.internal_as_unitary(self.function.__name__, kernel_args)

    def ctrl(self, *args):
        """
        Python-side stub for kernel.ctrl(...).

        Such calls are translated to C++ via the QJIT and are only valid from
        within another quantum kernel; calling this from Python always raises.

        Raises:
            AssertionError: always.
        """
        # Raise explicitly: `assert False` is stripped under `python -O`, which
        # would make this stub silently return None.
        raise AssertionError('This is an internal API call and will be translated to C++ via the QJIT.\nIt can only be called from within another quantum kernel.')

    def adjoint(self, *args):
        """
        Python-side stub for kernel.adjoint(...).

        Such calls are translated to C++ via the QJIT and are only valid from
        within another quantum kernel; calling this from Python always raises.

        Raises:
            AssertionError: always.
        """
        # Raise explicitly: `assert False` is stripped under `python -O`, which
        # would make this stub silently return None.
        raise AssertionError('This is an internal API call and will be translated to C++ via the QJIT.\nIt can only be called from within another quantum kernel.')

    def __internal_update_openqasm_qiskit_compat__(self, open_qasm_str, kernel_name):
        """
        Rewrite a flat OpenQASM program into an OpenQASM 3 kernel definition:

            OPENQASM 3;
            def <kernel_name> qubit[<size>]:<qreg name> {
              <remaining lines, each indented by two spaces>
            }

        'qreg' declarations supply the qubit-array name and size (the last one
        wins); lines containing 'creg', 'include' or 'OPENQASM' are dropped.

        Raises:
            ValueError: if open_qasm_str contains no 'qreg' declaration.
        """
        qreg_decl = re.compile(r'^\s*qreg\s+(\w+)\s*\[\s*(\d+)\s*\]')
        q_name = None
        q_size = None
        body = []
        for line in open_qasm_str.split('\n'):
            decl = qreg_decl.match(line)
            if decl:
                q_name, q_size = decl.group(1), decl.group(2)
                continue
            if 'creg' in line or 'include' in line or 'OPENQASM' in line:
                continue
            body.append('  {}\n'.format(line))
        if q_size is None:
            raise ValueError('No qreg declaration found in the OpenQASM source.')
        # Format only the header: body lines are emitted verbatim so any '{' / '}'
        # they contain (e.g. gate definitions) is not parsed as a format field.
        header = 'def {} qubit[{}]:{} {{\n'.format(kernel_name, q_size, q_name)
        return 'OPENQASM 3;\n' + header + ''.join(body) + '}'

    def mlir(self, *args, **kwargs):
        """
        Return the MLIR for this kernel evaluated at the given arguments.

        The kernel is emitted as OpenQASM via openqasm(), its 'OPENQASM 2.0'
        header is replaced by 'OPENQASM 3', and the result is handed to
        openqasm_to_mlir.

        Keyword args:
            kernel_name: name of the generated kernel (default: kernel_name()).
            add_entry_point: forwarded to openqasm_to_mlir (default: True).
            opt: forwarded to openqasm_to_mlir (default: 0).
            qiskit_compat: if truthy, first rewrite the OpenQASM into a
                qubit-array kernel definition; also forwarded (default: False).
        """
        assert len(args) == len(
            self.arg_names), "Cannot generate MLIR, you did not provided the correct concrete kernel arguments."
        target_name = kwargs.get('kernel_name', self.kernel_name())
        compat = kwargs.get('qiskit_compat', False)
        qasm_src = self.openqasm(*args).replace('OPENQASM 2.0', 'OPENQASM 3')
        if compat:
            qasm_src = self.__internal_update_openqasm_qiskit_compat__(qasm_src, target_name)
        return openqasm_to_mlir(qasm_src, target_name,
                                kwargs.get('add_entry_point', True),
                                kwargs.get('opt', 0),
                                compat)

    def llvm_mlir(self, *args, **kwargs):
        """
        Return the output of openqasm_to_llvm_mlir (LLVM-level MLIR, per the
        name) for this kernel evaluated at the given arguments.

        The kernel is emitted as OpenQASM via openqasm() and its 'OPENQASM 2.0'
        header is replaced by 'OPENQASM 3' before lowering.

        Keyword args:
            kernel_name: name of the generated kernel (default: kernel_name()).
            add_entry_point: forwarded to openqasm_to_llvm_mlir (default: True).
            opt: forwarded to openqasm_to_llvm_mlir (default: 0).
            qiskit_compat: if truthy, first rewrite the OpenQASM into a
                qubit-array kernel definition; also forwarded (default: False).
        """
        assert len(args) == len(
            self.arg_names), "Cannot generate LLVM MLIR, you did not provided the correct concrete kernel arguments."
        target_name = kwargs.get('kernel_name', self.kernel_name())
        compat = kwargs.get('qiskit_compat', False)
        qasm_src = self.openqasm(*args).replace('OPENQASM 2.0', 'OPENQASM 3')
        if compat:
            qasm_src = self.__internal_update_openqasm_qiskit_compat__(qasm_src, target_name)
            # NOTE(review): this path used to call `qasm_src.replace('u3', 'U')`
            # and discard the result (str is immutable), so no u3 -> U renaming
            # was ever applied; mlir() and llvm_ir() do not rename either. The
            # dead call is dropped to keep the emitted code unchanged. If the
            # rename is actually required, apply it to gate tokens only.
        return openqasm_to_llvm_mlir(qasm_src, target_name,
                                     kwargs.get('add_entry_point', True),
                                     kwargs.get('opt', 0),
                                     compat)

    def llvm_ir(self, *args, **kwargs):
        """
        Return the LLVM IR produced by openqasm_to_llvm_ir for this kernel
        evaluated at the given arguments.

        The kernel is emitted as OpenQASM via openqasm() and its 'OPENQASM 2.0'
        header is replaced by 'OPENQASM 3' before lowering.

        Keyword args:
            kernel_name: name of the generated kernel (default: kernel_name()).
            add_entry_point: forwarded to openqasm_to_llvm_ir (default: True).
            opt: forwarded to openqasm_to_llvm_ir (default: 0).
            qiskit_compat: if truthy, first rewrite the OpenQASM into a
                qubit-array kernel definition; also forwarded (default: False).
        """
        assert len(args) == len(
            self.arg_names), "Cannot generate LLVM IR, you did not provided the correct concrete kernel arguments."
        target_name = kwargs.get('kernel_name', self.kernel_name())
        compat = kwargs.get('qiskit_compat', False)
        qasm_src = self.openqasm(*args).replace('OPENQASM 2.0', 'OPENQASM 3')
        if compat:
            qasm_src = self.__internal_update_openqasm_qiskit_compat__(qasm_src, target_name)
        return openqasm_to_llvm_ir(qasm_src, target_name,
                                   kwargs.get('add_entry_point', True),
                                   kwargs.get('opt', 0),
                                   compat)

    def qir(self, *args, **kwargs):
        """
        Alias of llvm_ir(): every positional and keyword argument is forwarded
        unchanged and its result returned.
        """
        llvm_ir_src = self.llvm_ir(*args, **kwargs)
        return llvm_ir_src

    # Helper to construct the arg_dict (HetMap)
    # e.g. perform any additional type conversion if required.
    def construct_arg_dict(self, *args):
        """
        Build the argument dict (HetMap) handed to the QJIT, keyed by kernel
        argument name.

        Arguments annotated as a KernelSignature (typing.Callable[...]) must be
        qjit objects and are replaced by the hex string of their JIT-compiled
        function pointer. Arguments annotated as typing.List[typing.Callable[...]]
        must be lists of qjit objects and become a list of such hex strings.
        If a callable kernel is not yet a dependency of this kernel, it is
        registered and this kernel is recompiled into a fresh QJIT first.

        Invalid callable arguments, or a function pointer of 0, print an error
        and exit the process. Fewer positional args than kernel arguments
        raises IndexError.
        """
        args_dict = {}
        for i, arg_name in enumerate(self.arg_names):
            arg_value = args[i]
            args_dict[arg_name] = arg_value
            arg_type_str = str(self.type_annotations[arg_name])

            if arg_type_str.startswith('typing.Callable'):
                # The arg must be a qjit.
                if not isinstance(arg_value, qjit):
                    print('Invalid argument type for {}. A quantum kernel (qjit) is expected.'.format(arg_name))
                    sys.exit(1)
                # Handle runtime dependency:
                # The QJIT arg. was not *known* until invocation,
                # hence, we recompile this jit kernel taking into account
                # the KernelSignature argument.
                # TODO: perhaps an optimization that we can make is to
                # skip *eager* compilation for those kernels that have
                # KernelSignature arguments.
                if self._register_kernel_dependency(arg_value):
                    self._rejit_with_dependencies()
                # Replace the argument (in the dict) with the function pointer:
                # qjit is a pure-Python object, hence cannot be used by native QCOR.
                args_dict[arg_name] = hex(self._kernel_fn_ptr_or_exit(arg_value))

            elif arg_type_str.startswith('typing.List[typing.Callable['):
                callable_qjit_list = arg_value
                # Validate first so a non-qjit element gets the friendly error
                # instead of an AttributeError on .kernel_name().
                for clb in callable_qjit_list:
                    if not isinstance(clb, qjit):
                        print('Invalid argument type for {}. A list of quantum kernels (qjit) is expected.'.format(arg_name))
                        sys.exit(1)
                # Register every new dependency (a list, not any(), so nothing
                # short-circuits), then recompile at most once.
                newly_added = [self._register_kernel_dependency(clb)
                               for clb in callable_qjit_list]
                if any(newly_added):
                    self._rejit_with_dependencies()
                # Replace the argument (in the dict) with the list of function pointers.
                args_dict[arg_name] = [hex(self._kernel_fn_ptr_or_exit(clb))
                                       for clb in callable_qjit_list]
        return args_dict

    def _register_kernel_dependency(self, callable_qjit):
        """Record callable_qjit as a dependency of this kernel; return True if it was new."""
        dep_name = callable_qjit.kernel_name()
        if dep_name in self.sorted_kernel_dep:
            return False
        self.__kernels__graph.addKernelDependency(self.function.__name__, dep_name)
        self.sorted_kernel_dep = self.__kernels__graph.getSortedDependency(self.function.__name__)
        return True

    def _rejit_with_dependencies(self):
        """Recompile this kernel, with self.sorted_kernel_dep, into a brand-new QJIT."""
        # IMPORTANT: we cannot release a QJIT object till shut-down, so the old
        # one is parked in QJIT_OBJ_CACHE. Both the single-callable and the
        # list-of-callables paths go through here, so neither drops it.
        QJIT_OBJ_CACHE.append(self._qjit)
        self._qjit = QJIT()
        # NOTE(review): `extra_headers` is not defined in this method; it must
        # resolve to a module-level name at call time or this raises NameError.
        # Confirm (perhaps an instance attribute was intended).
        self._qjit.internal_python_jit_compile(self.src, self.sorted_kernel_dep, self.extra_cpp_code, extra_headers)

    def _kernel_fn_ptr_or_exit(self, callable_qjit):
        """Return callable_qjit's JIT-compiled function pointer; print an error and exit if it is 0."""
        fn_ptr = self._qjit.get_kernel_function_ptr(callable_qjit.kernel_name())
        if fn_ptr == 0:
            print('Failed to retrieve JIT-compiled function pointer for qjit kernel {}.'.format(callable_qjit.kernel_name()))
            sys.exit(1)
        return fn_ptr

    def __call__(self, *args):
        """
        Execute the decorated quantum kernel. This will directly
        invoke the corresponding LLVM JITed function pointer.

        Pass-by-reference arguments (annotated with FLOAT_REF, INT_REF, etc.,
        listed in self.ref_type_args) are written back afterwards. The updated
        value is read with qReg.getInformation(<arg name>) and assigned to the
        caller's *global* variable that was passed in, provided that variable
        currently holds an int or float. Caller variable names are recovered by
        parsing the calling source line, so this only works for single-line
        calls whose source is available.
        """
        args_dict = self.construct_arg_dict(*args)
        # Invoke the JITed function
        self._qjit.invoke(self.function.__name__, args_dict)

        # Nothing more to do unless there are *pass-by-ref* variables.
        if len(self.ref_type_args) == 0:
            return

        # Access the register:
        qReg = args_dict[self.qRegName]
        # The frame that called this kernel. Deleted in `finally` to avoid the
        # frame reference cycle described in the inspect module docs.
        caller_frame = inspect.currentframe().f_back
        try:
            code_context = inspect.getframeinfo(caller_frame).code_context
            if not code_context:
                print('Warning: could not retrieve the calling source line; pass-by-ref arguments were not updated.')
                return
            # Retrieve *original* variable names of the argument pack
            call_src = code_context[0].strip()
            caller_args = call_src[call_src.find('(') + 1:-1].split(',')
            caller_var_names = []
            for token in caller_args:
                token = token.strip()
                # keyword argument `name=var` -> take `var`
                caller_var_names.append(
                    token.split('=')[1].strip() if '=' in token else token)

            caller_globals = caller_frame.f_globals
            # Get the updated value:
            for by_ref_var in self.ref_type_args:
                updated_var = qReg.getInformation(by_ref_var)
                caller_var_name = caller_var_names[self.arg_names.index(by_ref_var)]
                if caller_var_name in caller_globals:
                    # We only support float and int atm
                    if isinstance(caller_globals[caller_var_name], (float, int)):
                        caller_globals[caller_var_name] = updated_var
        finally:
            del caller_frame


class KernelBuilder(object):
    """
    The QCOR KernelBuilder is a high-level data structure that enables the 
    development of qcor quantum kernels programmatically in Python. Example usage:

    from qcor import * 

    nq = 10
    builder = KernelBuilder() 

    builder.h(0)
    for i in range(nq-1):
        builder.cnot(i, i+1)
    builder.measure_all()
    ghz = builder.create()

    q = qalloc(nq)
    ghz(q)
    print(q.counts())

Mccaskey, Alex's avatar
Mccaskey, Alex committed
962
963
    If you do not provide a qreg argument to the constructor (py_args_dict) 
    we will assume a single qreg named q.
964
    """
965
966
    def __init__(self,**kwargs):
        self.kernel_args = kwargs['kernel_args'] if 'kernel_args' in kwargs else {}
Mccaskey, Alex's avatar
Mccaskey, Alex committed
967
        # Returns list of tuples, (name, nRequiredBits, isParameterized)
968
        all_instructions = internal_get_all_instructions()
969
        all_instructions = [element for element in all_instructions if element[0] != 'Measure']
970
971
        self.qjit_str = ''
        self.qreg_name = 'q'
Mccaskey, Alex's avatar
Mccaskey, Alex committed
972
        self.TAB = '    '
973
974
975
976
977

        for instruction in all_instructions:
            isParameterized = instruction[2]
            n_bits = instruction[1]
            name = instruction[0]
Mccaskey, Alex's avatar
Mccaskey, Alex committed
978

979
980
981
982
983
984
985
986
987
            qbits_str = ','.join(['q{}'.format(i) for i in range(n_bits)])
            qbits_indexed = ','.join(["{}[{{}}]".format(self.qreg_name) for i in range(n_bits)])
            new_func_str = '''def {}(self, {}, *args):
    params_str = ''
    params = []
    if len(args):
        for arg in args:
            if isinstance(arg, str):
                params.append(str(arg))
988
989
                if str(arg) not in self.kernel_args:
                    self.kernel_args[str(arg)] = float
990
991
            elif isinstance(arg, tuple):
                params.append(arg[0]+'['+str(arg[1])+']')
992
993
                if arg[0] not in self.kernel_args:
                    self.kernel_args[arg[0]] = List[float]
994
995
996
997
998
999
1000
            else:
                print('[KernelBuilder Error] Invalid parameter type.')
                exit(1)
        params_str = ','.join(params)
    if {} and len(args) == 0:
        print("[KernelBuilder Error] You are calling a parameterized instruction ({}) but have not provided any parameters")
        exit(1)
For faster browsing, not all history is shown. View entire blame