Commit cab7d6db authored by Mccaskey, Alex's avatar Mccaskey, Alex
Browse files

work on adding instruction broadcast using value semantics instructions. work...


work on adding instruction broadcast using value semantics instructions. work on getting ctrl region working with value semantics, update gate definitions to return last qubit results to support value semantics, various bug fixes, most tests now passing as they were

Signed-off-by: Mccaskey, Alex's avatarAlex McCaskey <mccaskeyaj@ornl.gov>
parent 2daeb9bb
Loading
Loading
Loading
Loading
+6 −0
Original line number Diff line number Diff line
@@ -62,11 +62,17 @@ def ArrayConcatOp : QuantumOp<"qarray_concat", []> {
def StartCtrlURegion : QuantumOp<"start_ctrl_u_region", []> {
  let arguments = (ins);
  let results = (outs);
  let printer = [{  
  p << "q.ctrl_region {";
  }];
}

def EndCtrlURegion : QuantumOp<"end_ctrl_u_region", []> {
  let arguments = (ins QubitType:$ctrl_qubit);
  let results = (outs);
  let printer = [{  auto op = *this;
  p << "} (ctrl_bit = " << op.ctrl_qubit() << ")";
  }];
}

def StartAdjointURegion : QuantumOp<"start_adj_u_region", []> {
+138 −58
Original line number Diff line number Diff line
@@ -54,18 +54,29 @@ void qasm3_visitor::createInstOps_HandleBroadcast(
  if (has_array_type(qbit_values)) {
    if (qbit_values.size() == 1) {
      auto n = get_qreg_size(qbit_values[0], qreg_names[0]);

      for (int i = 0; i < n; i++) {
        auto qubit_type = get_custom_opaque_type("Qubit", builder.getContext());

        auto extract_value = builder.create<mlir::quantum::ExtractQubitOp>(
            location, qubit_type, qbit_values[0],
            get_or_create_constant_integer_value(
                i, location, builder.getI64Type(), symbol_table, builder));
        builder.create<mlir::quantum::InstOp>(
            location, mlir::NoneType::get(builder.getContext()), str_attr,
            llvm::makeArrayRef(std::vector<mlir::Value>{extract_value}),
        auto extract_value = get_or_extract_qubit(qreg_names[0], i, location,
                                                  symbol_table, builder);

        std::vector<mlir::Type> ret_types;
        for (auto q : qbit_values) {
          ret_types.push_back(qubit_type);
        }
        auto inst = builder.create<mlir::quantum::ValueSemanticsInstOp>(
            location, llvm::makeArrayRef(ret_types), str_attr,
            llvm::makeArrayRef(extract_value),
            llvm::makeArrayRef(param_values));

        // Replace qbit_values in symbol table with new result qubits
        auto return_vals = inst.result();
        int ii = 0;
        for (auto result : return_vals) {
          symbol_table.replace_symbol(qreg_names[0] + std::to_string(i),
                                      result);
          ii++;
        }
      }
    } else if (qbit_values.size() == 2) {
      if (qbit_values[0].getType() == array_type &&
@@ -73,6 +84,8 @@ void qasm3_visitor::createInstOps_HandleBroadcast(
        auto n = get_qreg_size(qbit_values[0], qreg_names[0]);
        auto m = get_qreg_size(qbit_values[1], qreg_names[1]);

        // This case is cx qarray, rarray;

        if (n != m) {
          printErrorMessage("Gate broadcast must be on registers of same size.",
                            context);
@@ -82,59 +95,94 @@ void qasm3_visitor::createInstOps_HandleBroadcast(
          auto qubit_type =
              get_custom_opaque_type("Qubit", builder.getContext());

          auto extract_value_n = builder.create<mlir::quantum::ExtractQubitOp>(
              location, qubit_type, qbit_values[0],
              get_or_create_constant_integer_value(
                  i, location, builder.getI64Type(), symbol_table, builder));
          auto extract_value_m = builder.create<mlir::quantum::ExtractQubitOp>(
              location, qubit_type, qbit_values[1],
              get_or_create_constant_integer_value(
                  i, location, builder.getI64Type(), symbol_table, builder));
          auto extract_value_n = get_or_extract_qubit(
              qreg_names[0], i, location, symbol_table, builder);
          auto extract_value_m = get_or_extract_qubit(
              qreg_names[1], i, location, symbol_table, builder);

          builder.create<mlir::quantum::InstOp>(
              location, mlir::NoneType::get(builder.getContext()), str_attr,
              llvm::makeArrayRef(
                  std::vector<mlir::Value>{extract_value_n, extract_value_m}),
          std::vector<mlir::Type> ret_types;
          for (auto q : qbit_values) {
            ret_types.push_back(qubit_type);
          }
          auto inst = builder.create<mlir::quantum::ValueSemanticsInstOp>(
              location, llvm::makeArrayRef(ret_types), str_attr,
              llvm::makeArrayRef({extract_value_n, extract_value_m}),
              llvm::makeArrayRef(param_values));

          // Replace qbit_values in symbol table with new result qubits
          auto return_vals = inst.result();
          int ii = 0;
          for (auto result : return_vals) {
            symbol_table.replace_symbol(qreg_names[ii] + std::to_string(i),
                                        result);
            ii++;
          }
        }

      } else if (qbit_values[0].getType() == array_type &&
                 qbit_values[1].getType() != array_type) {
        auto n = get_qreg_size(qbit_values[0], qreg_names[0]);
        mlir::Value v = qbit_values[1];

        for (int i = 0; i < n; i++) {
          auto qubit_type =
              get_custom_opaque_type("Qubit", builder.getContext());

          auto extract_value = builder.create<mlir::quantum::ExtractQubitOp>(
              location, qubit_type, qbit_values[0],
              get_or_create_constant_integer_value(
                  i, location, builder.getI64Type(), symbol_table, builder));
          // This case is cx qarray, r;

          builder.create<mlir::quantum::InstOp>(
              location, mlir::NoneType::get(builder.getContext()), str_attr,
              llvm::makeArrayRef(
                  std::vector<mlir::Value>{extract_value, qbit_values[1]}),
          auto extract_value = get_or_extract_qubit(qreg_names[0], i, location,
                                                    symbol_table, builder);

          std::vector<mlir::Type> ret_types;
          for (auto q : qbit_values) {
            ret_types.push_back(qubit_type);
          }
          auto inst = builder.create<mlir::quantum::ValueSemanticsInstOp>(
              location, llvm::makeArrayRef(ret_types), str_attr,
              llvm::makeArrayRef({extract_value, v}),
              llvm::makeArrayRef(param_values));

          // Replace qbit_values in symbol table with new result qubits
          auto return_vals = inst.result();
          int ii = 0;
          for (auto result : return_vals) {
            symbol_table.replace_symbol(
                qreg_names[ii] + (ii == 1 ? "" : std::to_string(i)), result);
            ii++;
          }
          v = return_vals[1];
        }
      } else if (qbit_values[0].getType() != array_type &&
                 qbit_values[1].getType() == array_type) {
        auto n = get_qreg_size(qbit_values[1], qreg_names[1]);
        // This is cx q, rarray

        mlir::Value v = qbit_values[0];
        for (int i = 0; i < n; i++) {
          auto qubit_type =
              get_custom_opaque_type("Qubit", builder.getContext());

          auto extract_value = builder.create<mlir::quantum::ExtractQubitOp>(
              location, qubit_type, qbit_values[1],
              get_or_create_constant_integer_value(
                  i, location, builder.getI64Type(), symbol_table, builder));
          auto extract_value = get_or_extract_qubit(qreg_names[1], i, location,
                                                    symbol_table, builder);

          builder.create<mlir::quantum::InstOp>(
              location, mlir::NoneType::get(builder.getContext()), str_attr,
              llvm::makeArrayRef(
                  std::vector<mlir::Value>{qbit_values[0], extract_value}),
          std::vector<mlir::Type> ret_types;
          for (auto q : qbit_values) {
            ret_types.push_back(qubit_type);
          }
          auto inst = builder.create<mlir::quantum::ValueSemanticsInstOp>(
              location, llvm::makeArrayRef(ret_types), str_attr,
              llvm::makeArrayRef({v, extract_value}),
              llvm::makeArrayRef(param_values));

          // Replace qbit_values in symbol table with new result qubits
          auto return_vals = inst.result();
          int ii = 0;
          for (auto result : return_vals) {
            symbol_table.replace_symbol(
                qreg_names[ii] + (ii == 0 ? "" : std::to_string(i)), result);
            ii++;
          }
          v = return_vals[0];
        }
      }
    } else {
@@ -142,14 +190,11 @@ void qasm3_visitor::createInstOps_HandleBroadcast(
          "can only broadcast gates with one or two qubit registers");
    }
  } else {
    // std::cout << "WE ARE HERE " << name << "\n";

    if (symbol_table_qbit_keys.empty()) {
      builder.create<mlir::quantum::InstOp>(
          location, mlir::NoneType::get(builder.getContext()), str_attr,
          llvm::makeArrayRef(qbit_values), llvm::makeArrayRef(param_values));
    } else {
      // std::cout << "SYMBOL TABLE KEYS WAS NOT EMPTY\n";
      std::vector<mlir::Type> ret_types;
      for (auto q : qbit_values) {
        ret_types.push_back(qubit_type);
@@ -254,8 +299,14 @@ antlrcpp::Any qasm3_visitor::visitQuantumGateCall(
      auto idx_str = idx_identifier->expressionList()->expression(0)->getText();
      mlir::Value value;
      try {
        if (symbol_table.has_symbol(qbit_var_name + idx_str)) {
          value = symbol_table.get_symbol(qbit_var_name + idx_str);
        } else {
          // try catch is on this std::stoi(), if idx_str is not an integer,
          // then we drop out and try to evaluate the expression.
          value = get_or_extract_qubit(qbit_var_name, std::stoi(idx_str),
                                       location, symbol_table, builder);
        }
      } catch (...) {
        if (symbol_table.has_symbol(idx_str)) {
          auto qubits = symbol_table.get_symbol(qbit_var_name);
@@ -266,6 +317,8 @@ antlrcpp::Any qasm3_visitor::visitQuantumGateCall(

          value = builder.create<mlir::quantum::ExtractQubitOp>(
              location, qubit_type, qubits, qbit);
          if (!symbol_table.has_symbol(qbit_var_name + idx_str))
            symbol_table.add_symbol(qbit_var_name + idx_str, value);
        } else {
          qasm3_expression_generator exp_generator(builder, symbol_table,
                                                   file_name, qubit_type);
@@ -283,17 +336,15 @@ antlrcpp::Any qasm3_visitor::visitQuantumGateCall(
            }
            value = builder.create<mlir::quantum::ExtractQubitOp>(
                location, qubit_type, qubits, value);
            if (!symbol_table.has_symbol(qbit_var_name + idx_str))
              symbol_table.add_symbol(qbit_var_name + idx_str, value);
          }
          // printErrorMessage(
          //     "Invalid measurement index on the given qubit register: " +
          //     qbit_var_name + ", " + idx_str);
        }
      }

      // auto qbit =
      //     get_or_extract_qubit(qbit_var_name, std::stoi(idx_str), location);
      qbit_values.push_back(value);
      qubit_symbol_table_keys.push_back(qbit_var_name + idx_str);

    } else {
      // this is a qubit
      auto qbit = symbol_table.get_symbol(qbit_var_name);
@@ -342,19 +393,29 @@ antlrcpp::Any qasm3_visitor::visitQuantumGateCall(
  if (has_ctrl) {
    ctrl_bit = *qbit_values.begin();
    qbit_values.erase(qbit_values.begin());
    qubit_symbol_table_keys.erase(qubit_symbol_table_keys.begin());
  }

  if (symbol_table.has_seen_function(name)) {
    std::vector<mlir::Value> operands;
    std::vector<mlir::Type> result_types;
    for (auto p : param_values) {
      operands.push_back(p);
    }
    for (auto q : qbit_values) {
      operands.push_back(q);
      result_types.push_back(qubit_type);
    }

    builder.create<mlir::CallOp>(location, symbol_table.get_seen_function(name),
                                 operands);
    auto call_op = builder.create<mlir::CallOp>(
        location, symbol_table.get_seen_function(name), operands);

    auto return_vals = call_op.getResults();
    int i = 0;
    for (auto result : return_vals) {
      symbol_table.replace_symbol(qubit_symbol_table_keys[i], result);
      i++;
    }

  } else {
    createInstOps_HandleBroadcast(name, qbit_values, qreg_names,
@@ -369,6 +430,7 @@ antlrcpp::Any qasm3_visitor::visitQuantumGateCall(
    } else if (top.first == EndAction::EndAdjU) {
      builder.create<mlir::quantum::EndAdjointURegion>(location);
    } else if (top.first == EndAction::EndCtrlU) {

      builder.create<mlir::quantum::EndCtrlURegion>(location, ctrl_bit);
    }
    action_and_extrainfo.pop();
@@ -604,15 +666,32 @@ antlrcpp::Any qasm3_visitor::visitSubroutineCall(

  auto str_attr = builder.getStringAttr(name);

  std::vector<std::string> qreg_names;
  std::vector<std::string> qreg_names, qubit_symbol_table_keys;
  auto n_qubit_args = expression_list[qubit_expr_list_idx]->expression().size();
  for (auto expression : expression_list[qubit_expr_list_idx]->expression()) {
    auto tmp_key = expression->getText();
    tmp_key.erase(std::remove(tmp_key.begin(), tmp_key.end(), '['),
                  tmp_key.end());
    tmp_key.erase(std::remove(tmp_key.begin(), tmp_key.end(), ']'),
                  tmp_key.end());
    qreg_names.push_back(tmp_key);

    mlir::Value tmp;
    if (symbol_table.has_symbol(tmp_key)) {
      tmp = symbol_table.get_symbol(tmp_key);
    } else {
      qasm3_expression_generator qubit_exp_generator(builder, symbol_table,
                                                     file_name, qubit_type);
      qubit_exp_generator.visit(expression);
      auto qbit_or_qreg = qubit_exp_generator.current_value;
    qbit_values.push_back(qubit_exp_generator.current_value);
    qreg_names.push_back(expression->getText());
      tmp = qbit_or_qreg;
      if (!symbol_table.has_symbol(tmp_key))
        symbol_table.add_symbol(tmp_key, tmp);
    }

    qbit_values.push_back(tmp);

    qubit_symbol_table_keys.push_back(tmp_key);
  }

  if (symbol_table.has_seen_function(name)) {
@@ -628,8 +707,9 @@ antlrcpp::Any qasm3_visitor::visitSubroutineCall(
                                 operands);

  } else {
    createInstOps_HandleBroadcast(name, qbit_values, qreg_names, {},
                                  param_values, location, context);
    createInstOps_HandleBroadcast(name, qbit_values, qreg_names,
                                  qubit_symbol_table_keys, param_values,
                                  location, context);
  }
  return 0;
}
+25 −7
Original line number Diff line number Diff line
@@ -36,8 +36,6 @@ antlrcpp::Any qasm3_visitor::visitQuantumDeclaration(
      try {
        size = std::stoi(exp_list->expression(0)->getText());
      } catch (...) {

        
        // check if this is a constant expression
        qasm3_expression_generator exp_generator(builder, symbol_table,
                                                 file_name);
@@ -52,7 +50,8 @@ antlrcpp::Any qasm3_visitor::visitQuantumDeclaration(
                "This variable qubit size must be a constant integer.");
          }
        } else {
          size = symbol_table.get_global_constant<int64_t>(exp_list->expression(0)->getText());
          size = symbol_table.get_global_constant<int64_t>(
              exp_list->expression(0)->getText());
        }
      }
    }
@@ -127,7 +126,7 @@ antlrcpp::Any qasm3_visitor::visitQuantumGateDefinition(

  auto main_block = builder.saveInsertionPoint();

  auto func_type = builder.getFunctionType(func_args, llvm::None);
  auto func_type = builder.getFunctionType(func_args, func_args);
  auto proto =
      mlir::FuncOp::create(builder.getUnknownLoc(), gate_call_name, func_type);
  mlir::FuncOp function(proto);
@@ -146,7 +145,26 @@ antlrcpp::Any qasm3_visitor::visitQuantumGateDefinition(

  auto ret = visitChildren(quantum_block);

  builder.create<mlir::ReturnOp>(builder.getUnknownLoc());
  // Can I walk the use chain of the block arguments
  // and get the resultant qubit values taht I can then return
  // from this custom gate definition
  std::vector<mlir::Value> result_qubit_vals;
  for (auto arg : entryBlock.getArguments()) {
    auto users = arg.getUsers();
    mlir::Value last_user;
    if (!users.empty()) {
      last_user = (*users.begin())->getResult(0);
      users = last_user.getUsers();
    }
    result_qubit_vals.push_back(last_user);
  }

  std::cout << "GATE " << gate_call_name << " has " << result_qubit_vals.size() << " to return.\n";
  for (auto v : result_qubit_vals) {
    v.dump();
  }

  builder.create<mlir::ReturnOp>(builder.getUnknownLoc(), llvm::makeArrayRef(result_qubit_vals));

  m_module.push_back(function);

+8 −2
Original line number Diff line number Diff line
@@ -94,8 +94,14 @@ LogicalResult EndCtrlURegionOpLowering::matchAndRewrite(
    }
  }();

  rewriter.create<mlir::CallOp>(location, qir_get_fn_ptr,
                                LLVM::LLVMVoidType::get(context), operands);
  mlir::Value ctrl_bit = operands[0];
  if (auto q_op = ctrl_bit.getDefiningOp<mlir::quantum::ValueSemanticsInstOp>()) {
    ctrl_bit = q_op.getOperands()[0];
  }

  rewriter.create<mlir::CallOp>(
      location, qir_get_fn_ptr, LLVM::LLVMVoidType::get(context),
      llvm::makeArrayRef(std::vector<mlir::Value>{ctrl_bit}));

  rewriter.eraseOp(op);

+0 −6
Original line number Diff line number Diff line
@@ -63,7 +63,6 @@ LogicalResult ExtractQubitOpConversion::matchAndRewrite(

  auto get_qbit_qir_call = rewriter.create<mlir::CallOp>(
      location, symbol_ref, array_qbit_type, operands);
  // ArrayRef<Value>({vars[qreg_name], adaptor.idx()}));

  auto bitcast = rewriter.create<LLVM::BitcastOp>(
      location,
@@ -75,11 +74,6 @@ LogicalResult ExtractQubitOpConversion::matchAndRewrite(
      bitcast.res());

  rewriter.replaceOp(op, real_casted_qubit.res());
  // Remember the variable name for this qubit
  // vars.insert({qubit_var_name, real_casted_qubit.res()});

  // STORE THAT THIS OP PRODUCES THIS QREG{IDX} VARIABLE NAME
  // qubit_extract_map.insert({op, qubit_var_name});

  return success();
}
Loading