Commit 636ced52 authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

Translating ref types to Python JIT Kernels



The Python ref type (custom type) will be propagated to C++ kernels, e.g. int& or double&. This will make sure that nested kernels (a kernel calls other kernels) will support pass-by-reference appropriately.

The difficulty is *inline* unpacking to a r-val is not compatible with the by-reference ctor arg. Hence, adjust the codegen to upack those special cases to local var before passing to the ctor.

This will only occur at the Python-C++ interface level where we have a special mechanism to persist these pass-by-ref vars to the qreg and reflect them back to the Python side.

Signed-off-by: default avatarThien Nguyen <nguyentm@ornl.gov>
parent 5c3c816a
Loading
Loading
Loading
Loading
+117 −40
Original line number Diff line number Diff line
@@ -200,44 +200,44 @@ void QCORSyntaxHandler::GetReplacement(
  }
  OS << ") {}\n";

  if (add_het_map_ctor) {
    // Third constructor, give us a way to provide a HeterogeneousMap of
    // arguments, this is used for Pythonic QJIT...
    // KERNEL_NAME(HeterogeneousMap args);
    OS << kernel_name << "(HeterogeneousMap& args): QuantumKernel<"
       << kernel_name << ", " << program_arg_types[0];
    for (int i = 1; i < program_arg_types.size(); i++) {
      OS << ", " << program_arg_types[i];
    }
    OS << "> (args.get<" << program_arg_types[0] << ">(\""
       << program_parameters[0] << "\")";
    for (int i = 1; i < program_parameters.size(); i++) {
      OS << ", "
         << "args.get<" << program_arg_types[i] << ">(\""
         << program_parameters[i] << "\")";
    }
    OS << ") {}\n";

    // Forth constructor, give us a way to provide a HeterogeneousMap of
    // arguments, and set a parent kernel - this is also used for Pythonic
    // QJIT... KERNEL_NAME(std::shared_ptr<CompositeInstruction> parent,
    // HeterogeneousMap args);
    OS << kernel_name
       << "(std::shared_ptr<CompositeInstruction> parent, HeterogeneousMap& "
          "args): QuantumKernel<"
       << kernel_name << ", " << program_arg_types[0];
    for (int i = 1; i < program_arg_types.size(); i++) {
      OS << ", " << program_arg_types[i];
    }
    OS << "> (parent, args.get<" << program_arg_types[0] << ">(\""
       << program_parameters[0] << "\")";
    for (int i = 1; i < program_parameters.size(); i++) {
      OS << ", "
         << "args.get<" << program_arg_types[i] << ">(\""
         << program_parameters[i] << "\")";
    }
    OS << ") {}\n";
  }
  // if (add_het_map_ctor) {
  //   // Third constructor, give us a way to provide a HeterogeneousMap of
  //   // arguments, this is used for Pythonic QJIT...
  //   // KERNEL_NAME(HeterogeneousMap args);
  //   OS << kernel_name << "(HeterogeneousMap& args): QuantumKernel<"
  //      << kernel_name << ", " << program_arg_types[0];
  //   for (int i = 1; i < program_arg_types.size(); i++) {
  //     OS << ", " << program_arg_types[i];
  //   }
  //   OS << "> (args.get<" << program_arg_types[0] << ">(\""
  //      << program_parameters[0] << "\")";
  //   for (int i = 1; i < program_parameters.size(); i++) {
  //     OS << ", "
  //        << "args.get<" << program_arg_types[i] << ">(\""
  //        << program_parameters[i] << "\")";
  //   }
  //   OS << ") {}\n";

  //   // Forth constructor, give us a way to provide a HeterogeneousMap of
  //   // arguments, and set a parent kernel - this is also used for Pythonic
  //   // QJIT... KERNEL_NAME(std::shared_ptr<CompositeInstruction> parent,
  //   // HeterogeneousMap args);
  //   OS << kernel_name
  //      << "(std::shared_ptr<CompositeInstruction> parent, HeterogeneousMap& "
  //         "args): QuantumKernel<"
  //      << kernel_name << ", " << program_arg_types[0];
  //   for (int i = 1; i < program_arg_types.size(); i++) {
  //     OS << ", " << program_arg_types[i];
  //   }
  //   OS << "> (parent, args.get<" << program_arg_types[0] << ">(\""
  //      << program_parameters[0] << "\")";
  //   for (int i = 1; i < program_parameters.size(); i++) {
  //     OS << ", "
  //        << "args.get<" << program_arg_types[i] << ">(\""
  //        << program_parameters[i] << "\")";
  //   }
  //   OS << ") {}\n";
  // }

  // Destructor definition
  OS << "virtual ~" << kernel_name << "() {\n";
@@ -331,16 +331,93 @@ void QCORSyntaxHandler::GetReplacement(
  OS << "}\n";

  if (add_het_map_ctor) {
    // Remove "&" from type string before getting the Python variables in the HetMap.
    // Note: HetMap can't store references.
    const auto remove_ref_arg_type = [](const std::string &org_arg_type) -> std::string {
      // We intentially only support a very limited set of pass-by-ref types
      // from the HetMap.
      // Only do: double& and int&
      if (org_arg_type == "double&") {
        return "double";
      }
      if (org_arg_type == "int&") {
        return "int";
      }
      // Keep the type string.
      return org_arg_type;
    };

    // Strategy: we unpack the args in the HetMap and call
    // the appropriate ctor overload.
    
    // For reference ctor params (e.g. double& and int&),
    // we create a local variable to copy the arg from the HetMap
    // before passing to the ctor.
    // We have a special machanism to handle *pass-by-reference*
    // in the Python side.
    // Non-reference types will just use inline `args.get<T>(key)` to unpack
    // the arguments.

    // List of resolved argument strings for ctor calls.
    std::vector<std::string> arg_ctor_list;
    // Code to copy *ref* type arguments from the HetMap.
    // This *must* be injected before the ctor call.
    std::stringstream ref_type_copy_decl_ss;
    int var_counter = 0;
    // Only handle non-qreg args
    for (int i = 1; i < program_parameters.size(); i++) {
      // If this is a *supported* ref types: double&, int&, etc. 
      if (remove_ref_arg_type(program_arg_types[i]) != program_arg_types[i]) {
        // Generate a temp var
        const std::string new_var_name = "__temp_var__" + std::to_string(var_counter++);
        // Copy the var from HetMap to the temp var
        ref_type_copy_decl_ss << remove_ref_arg_type(program_arg_types[i]) << " "<< new_var_name << " = " << "args.get<" << remove_ref_arg_type(program_arg_types[i]) << ">(\""
         << program_parameters[i] << "\");\n";
        
        // We just pass this copied var to the ctor 
        // where it expects a reference type.
        arg_ctor_list.emplace_back(new_var_name); 
      }
      else {
        // Otherwise, just unpack the arg inline in the ctor call.
        std::stringstream ss;
        ss << "args.get<" << program_arg_types[i] << ">(\""<< program_parameters[i] << "\")";
        arg_ctor_list.emplace_back(ss.str());
      }
    }

    // Add the HeterogeneousMap args function overload
    OS << "void " << kernel_name
       << "__with_hetmap_args(HeterogeneousMap& args) {\n";
    OS << "class " << kernel_name << " __ker__temp__(args);\n";
    // First, inject any copying statements required to unpack *ref* types.
    OS << ref_type_copy_decl_ss.str();
    // CTor call
    OS << "class " << kernel_name << " __ker__temp__(";
    // First arg: qreg
    OS << "args.get<" << program_arg_types[0] << ">(\""
       << program_parameters[0] << "\")";
    // The rest: either inline unpacking or temp var names (ref type)
    for (const auto &arg_str: arg_ctor_list) {
      OS << ", " << arg_str;
    }
    OS << ");\n";
    OS << "}\n";

    OS << "void " << kernel_name
       << "__with_parent_and_hetmap_args(std::shared_ptr<CompositeInstruction> parent, "
          "HeterogeneousMap& args) {\n";
    OS << "class " << kernel_name << " __ker__temp__(parent, args);\n";
    OS << ref_type_copy_decl_ss.str();
    // CTor call with parent kernel
    OS << "class " << kernel_name << " __ker__temp__(parent, ";
    // Second arg: qreg
    OS << "args.get<" << program_arg_types[0] << ">(\""
       << program_parameters[0] << "\")";
    // The rest: either inline unpacking or temp var names (ref type)
    for (const auto &arg_str: arg_ctor_list) {
      OS << ", " << arg_str;
    }
    OS << ");\n";   
    // The rest: either inline unpacking or temp var names (ref type)
    OS << "}\n";
  }
  auto s = OS.str();
+22 −6
Original line number Diff line number Diff line
@@ -3,17 +3,33 @@ import math

# python3 test_ftqc.py -qrt ftqc

# Note: Must use FTQC runtime to get out_meas_z
@qjit
def test(q : qreg, out_meas_z: FLOAT_REF):
def testH0(q : qreg, out_meas_z: FLOAT_REF):
    print("Test H0: Input: ", out_meas_z)
    H(q[0])
    if Measure(q[0]):
        out_meas_z = -1.0
        out_meas_z = -1.0 + out_meas_z
    else:
        out_meas_z = 1.0 + out_meas_z
    print("Test H0: Output: ", out_meas_z)

@qjit
def testH1(q : qreg, out_meas_z: FLOAT_REF):
    print("Test H1: Input: ", out_meas_z)
    H(q[1])
    if Measure(q[1]):
        out_meas_z = -1.0 + out_meas_z
    else:
        out_meas_z = 1.0
        out_meas_z = 1.0 + out_meas_z
    print("Test H1: Output: ", out_meas_z)

# Note: Must use FTQC runtime to get out_meas_z
@qjit
def test(q : qreg, out_meas_z: FLOAT_REF):
    testH0(q, out_meas_z)
    testH1(q, out_meas_z)

q = qalloc(1)
q = qalloc(2)
result = 0.0
test(q, result)
# Flipping 1.0; -1.0 (50-50)
print("Result =", result)
+3 −1
Original line number Diff line number Diff line
@@ -156,8 +156,10 @@ class qjit(object):
        self.qRegName = ''
        for arg, _type in self.type_annotations.items():
            if _type is FLOAT_REF:
                _type = float
                self.float_ref_args.append(arg)
                cpp_arg_str += ',' + \
                    'double& ' + arg
                continue
            if str(_type) not in self.allowed_type_cpp_map:
                print('Error, this quantum kernel arg type is not allowed: ', str(_type))
                exit(1)