Commit 717dc509 authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

[WIP] passing variables by referene via pyxasm



Signed-off-by: default avatarThien Nguyen <nguyentm@ornl.gov>
parent be21c35c
Loading
Loading
Loading
Loading
+19 −0
Original line number Diff line number Diff line
from qcor import *
import math

# python3 test_ftqc.py -qrt ftqc

# Note: Must use FTQC runtime to get out_meas_z
@qjit
def test(q : qreg, out_meas_z: FLOAT_REF):
    H(q[0])
    if Measure(q[0]):
        out_meas_z = -1.0
    else:
        out_meas_z = 1.0

q = qalloc(1)
result = 0.0
test(q, result)
# Flipping 1.0; -1.0 (50-50)
print("Result =", result)
+7 −2
Original line number Diff line number Diff line
@@ -298,8 +298,13 @@ PYBIND11_MODULE(_pyqcor, m) {
      .def("size", &xacc::internal_compiler::qreg::size, "")
      .def("print", &xacc::internal_compiler::qreg::print, "")
      .def("counts", &xacc::internal_compiler::qreg::counts, "")
      .def("exp_val_z", &xacc::internal_compiler::qreg::exp_val_z, "");

      .def("exp_val_z", &xacc::internal_compiler::qreg::exp_val_z, "")
      .def(
          "getInformation",
          [](xacc::internal_compiler::qreg &q, const std::string &key) {
            return q.results()->getInformation(key);
          },
          "");
  // m.def("createObjectiveFunction", [](const std::string name, ))
  py::class_<qcor::QJIT, std::shared_ptr<qcor::QJIT>>(m, "QJIT", "")
      .def(py::init<>(), "")
+43 −2
Original line number Diff line number Diff line
@@ -9,7 +9,8 @@ from collections import defaultdict

List = typing.List
PauliOperator = xacc.quantum.PauliOperator

FLOAT_REF = typing.NewType('value', float)
INT_REF = typing.NewType('value', int)

def X(idx):
    return xacc.quantum.PauliOperator({idx: 'X'}, 1.0)
@@ -151,10 +152,17 @@ class qjit(object):

        # Construct the C++ kernel arg string
        cpp_arg_str = ''
        self.float_ref_args = []
        self.qRegName = ''
        for arg, _type in self.type_annotations.items():
            if _type is FLOAT_REF:
                _type = float
                self.float_ref_args.append(arg)
            if str(_type) not in self.allowed_type_cpp_map:
                print('Error, this quantum kernel arg type is not allowed: ', str(_type))
                exit(1)
            if self.allowed_type_cpp_map[str(_type)] == 'qreg':
                self.qRegName = arg
            cpp_arg_str += ',' + \
                self.allowed_type_cpp_map[str(_type)] + ' ' + arg
        cpp_arg_str = cpp_arg_str[1:]
@@ -192,10 +200,15 @@ class qjit(object):
                fbody_src = fbody_src.replace(
                    aliasModuleStr, originalModuleStr)

        # Persist *pass by ref* variables to the accelerator buffer:
        persist_by_ref_var_code = ''
        for ref_var in self.float_ref_args:
            persist_by_ref_var_code += '\npersist_var_to_qreq(\"' + ref_var + '\", ' + ref_var + ', '+ self.qRegName + ')' 

        # Create the qcor quantum kernel function src for QJIT and the Clang syntax handler
        self.src = '__qpu__ void '+self.function.__name__ + \
            '('+cpp_arg_str+') {\nusing qcor::pyxasm;\n' + \
            globalDeclStr + '\n' + fbody_src + "}\n"
            globalDeclStr + '\n' + fbody_src + persist_by_ref_var_code + "}\n"

        # Handle nested kernels:
        dependency = []
@@ -331,6 +344,34 @@ class qjit(object):
        # Invoke the JITed function
        self._qjit.invoke(self.function.__name__, args_dict)
        
        # Update any *by-ref* arguments: annotated with the custom type: FLOAT_REF, INT_REF, etc.
        # If there are *pass-by-ref* variables:
        if len(self.float_ref_args) > 0:
            # Access the register:
            qReg = args_dict[self.qRegName]
            # Retrieve *original* variable names of the argument pack
            frame = inspect.currentframe()
            frame = inspect.getouterframes(frame)[1]
            code_context_string = inspect.getframeinfo(frame[0]).code_context[0].strip()
            caller_args = code_context_string[code_context_string.find('(') + 1:-1].split(',')
            caller_var_names = []
            for i in caller_args:
                i = i.strip()
                if i.find('=') != -1:
                    caller_var_names.append(i.split('=')[1].strip())
                else:
                    caller_var_names.append(i)
            
            # Get the updated value:
            for by_ref_var in self.float_ref_args:
                updated_var = qReg.getInformation(by_ref_var)
                caller_var_name = caller_var_names[self.arg_names.index(by_ref_var)]
                if (caller_var_name in inspect.stack()[1][0].f_globals):
                    # Make sure it is the correct type:
                    by_ref_instane = inspect.stack()[1][0].f_globals[caller_var_name] 
                    if (isinstance(by_ref_instane, float)):
                        inspect.stack()[1][0].f_globals[caller_var_name] = updated_var

        return


+5 −0
Original line number Diff line number Diff line
@@ -97,6 +97,11 @@ template <typename T, typename... TAIL> void print(const T &t, TAIL... tail) {
  print(tail...);
}

template <typename T>
void persist_var_to_qreq(const std::string &key, T &val, qreg &q) {
  q.results()->addExtraInfo(key, val);
}

// The TranslationFunctor maps vector<double> to a tuple of Args...
template <typename... Args>
using TranslationFunctor =