Commit 7c6bfa6a authored by Mccaskey, Alex's avatar Mccaskey, Alex
Browse files

work on subroutines, can take classical and quantum args. added some new tests to implement later

parent 1ff3c72f
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -13,6 +13,7 @@ file(GLOB SRC *.cpp antlr/generated/*.cpp utils/*.cpp
   visitor_handlers/measurement_handler.cpp
   visitor_handlers/loop_stmt_handler.cpp
   visitor_handlers/conditional_handler.cpp
   visitor_handlers/subroutine_handler.cpp
   )

add_library(${LIBRARY_NAME} SHARED ${SRC})
+1 −0
Original line number Diff line number Diff line
@@ -146,6 +146,7 @@ void OpenQasmV3MLIRGenerator::finalize_mlirgen() {
  }

  if (add_main) {
    builder.setInsertionPointToEnd(main_entry_block);
    builder.create<mlir::ReturnOp>(builder.getUnknownLoc(),
                                   llvm::ArrayRef<mlir::Value>());
  }
+4 −4
Original line number Diff line number Diff line
@@ -53,10 +53,10 @@ class qasm3_visitor : public qasm3::qasm3BaseVisitor {
      qasm3Parser::QuantumMeasurementAssignmentContext* context) override;

  // // see visitor_handlers/subroutine_handler.cpp
  // antlrcpp::Any visitSubroutineDefinition(
  //     qasm3Parser::SubroutineDefinitionContext* context) override;
  // antlrcpp::Any visitReturnStatement(
  //     qasm3Parser::ReturnStatementContext* context) override;
  antlrcpp::Any visitSubroutineDefinition(
      qasm3Parser::SubroutineDefinitionContext* context) override;
  antlrcpp::Any visitReturnStatement(
      qasm3Parser::ReturnStatementContext* context) override;

  // see visitor_handlers/conditional_handler.cpp
  antlrcpp::Any visitBranchingStatement(
+143 −0
Original line number Diff line number Diff line
@@ -271,6 +271,149 @@ while (i < 10) {
  std::cout << mlir << "\n";
  qcor::execute("qasm3", while_stmt, "while_stmt");
}


TEST(qasm3VisitorTester, checkSubroutine) {
  const std::string subroutine_test = R"#(OPENQASM 3;
include "qelib1.inc";
def xmeasure qubit:q -> bit { h q; return measure q; }
qubit q, qq[2];
bit r, rr[2];

rr[0] = xmeasure q;
r = xmeasure qq[0];
)#";
  auto mlir = qcor::mlir_compile("qasm3", subroutine_test, "subroutine_test",
                                 qcor::OutputType::MLIR, false);

  std::cout << "subroutine_test MLIR:\n" << mlir << "\n";

}

TEST(qasm3VisitorTester, checkSubroutine2) {
  const std::string subroutine_test = R"#(OPENQASM 3;
include "qelib1.inc";
def xcheck qubit[4]:d, qubit:a -> bit {
  // reset a;
  for i in [1: 3] cx d[i], a;
  return measure a;
}
qubit q;
const n = 10;
def parity(bit[n]:cin) -> bit {
  bit c;
  for i in [0: n-1] {
    c ^= cin[i];
  }
  return c;
}
)#";
  auto mlir = qcor::mlir_compile("qasm3", subroutine_test, "subroutine_test",
                                 qcor::OutputType::MLIR, false);

  std::cout << "subroutine_test MLIR:\n" << mlir << "\n";

}

TEST(qasm3VisitorTester, checkSubroutine3) {
  const std::string subroutine_test = R"#(OPENQASM 3;
include "qelib1.inc";
const n = 10;
def xmeasure qubit:q -> bit { h q; return measure q; }
def ymeasure qubit:q -> bit { s q; h q; return measure q; }

def pauli_measurement(bit[2*n]:spec) qubit[n]:q -> bit {
  bit b;
  for i in [0: n - 1] {
    bit temp;
    if(spec[i]==1 && spec[n+i]==0) { temp = xmeasure q[i]; }
    if(spec[i]==0 && spec[n+i]==1) { temp = measure q[i]; }
    if(spec[i]==1 && spec[n+i]==1) { temp = ymeasure q[i]; }
    b ^= temp;
  }
  return b;
}
)#";
  auto mlir = qcor::mlir_compile("qasm3", subroutine_test, "subroutine_test",
                                 qcor::OutputType::MLIR, false);

  std::cout << "subroutine_test MLIR:\n" << mlir << "\n";

}

TEST(qasm3VisitorTester, checkSubroutine4) {
  const std::string subroutine_test = R"#(OPENQASM 3;
include "qelib1.inc";
const buffer_size = 30;

def ymeasure qubit:q -> bit { s q; h q; return measure q; }

def test(int[32]:addr) qubit:q, qubit[buffer_size]:buffer {
  bit outcome;
  cy buffer[addr], q;
  outcome = ymeasure buffer[addr];
  if(outcome == 1) ry(pi / 2) q;
}
)#";
  auto mlir = qcor::mlir_compile("qasm3", subroutine_test, "subroutine_test",
                                 qcor::OutputType::MLIR, false);

  std::cout << "subroutine_test MLIR:\n" << mlir << "\n";

}

// TO IMPLEMENT 

TEST(qasm3VisitorTester, checkMeasureRange) {
  const std::string meas_range = R"#(OPENQASM 3;
include "qelib1.inc";
qubit a[4], b[4];
bit ans[5];
measure b[0:3] -> ans[0:3];
)#";
  auto mlir = qcor::mlir_compile("qasm3", meas_range, "meas_range",
                                 qcor::OutputType::MLIR, false);

  std::cout << "meas_range MLIR:\n" << mlir << "\n";

}

TEST(qasm3VisitorTester, checkGate) {
  const std::string gate_def = R"#(OPENQASM 3;
include "qelib1.inc";
gate cphase(x) a, b
{
  U(0, 0, x / 2) a;
  CX a, b;
  U(0, 0, -x / 2) b;
  CX a, b;
  U(0, 0, x / 2) b;
}
cphase(pi / 2) q[0], q[1];
)#";
  auto mlir = qcor::mlir_compile("qasm3", gate_def, "gate_def",
                                 qcor::OutputType::MLIR, false);

  std::cout << "gate_def MLIR:\n" << mlir << "\n";

}
TEST(qasm3VisitorTester, checkCastBitToInt) {

   const std::string cast_int = R"#(OPENQASM 3;
include "qelib1.inc";
bit c[4] = "1111";
int[4] t;
t = int[4](c);
print(t);
)#";
  auto mlir = qcor::mlir_compile("qasm3", cast_int, "cast_int",
                                 qcor::OutputType::MLIR, false);

  std::cout << "cast_int MLIR:\n" << mlir << "\n";
}



int main(int argc, char **argv) {
  ::testing::InitGoogleTest(&argc, argv);
  auto ret = RUN_ALL_TESTS();
+62 −6
Original line number Diff line number Diff line
@@ -36,7 +36,9 @@ antlrcpp::Any qasm3_expression_generator::visitTerminal(
    // We have hit a closing on an index
    // std::cout << "TERMNODE:\n";
    indexed_variable_value = current_value;
    if (casting_indexed_integer_to_bool) {
      internal_value_type = builder.getIndexType();
    }
  } else if (node->getSymbol()->getText() == "]") {
    if (casting_indexed_integer_to_bool) {
      // We have an indexed integer in indexed_variable_value
@@ -50,7 +52,6 @@ antlrcpp::Any qasm3_expression_generator::visitTerminal(
      // uint[4] b_in = 15; // b = 1111
      // bool(b_in[1]);

      
      // auto number_value = builder.create<mlir::LoadOp>(location,
      // indexed_variable_value, get_or_create_constant_index_value(0,
      // location)); number_value.dump(); auto idx_minus_1 =
@@ -85,9 +86,22 @@ antlrcpp::Any qasm3_expression_generator::visitTerminal(
      update_current_value(and_value.result());
      casting_indexed_integer_to_bool = false;
    } else {
      if (internal_value_type.isa<mlir::OpaqueType>() &&
      if (internal_value_type.dyn_cast_or_null<mlir::OpaqueType>() &&
          internal_value_type.cast<mlir::OpaqueType>().getTypeData().str() ==
              "Qubit") {
        if (current_value.getType().isa<mlir::MemRefType>()) {
          if (current_value.getType().cast<mlir::MemRefType>().getRank() == 1 &&
              current_value.getType().cast<mlir::MemRefType>().getShape()[0] ==
                  1) {
            current_value = builder.create<mlir::LoadOp>(
                location, current_value,
                get_or_create_constant_index_value(0, location, 64,
                                                   symbol_table, builder));
          } else {
            printErrorMessage("Terminator ']' -> Invalid qubit array index: ",
                              current_value);
          }
        }
        update_current_value(builder.create<mlir::quantum::ExtractQubitOp>(
            location, get_custom_opaque_type("Qubit", builder.getContext()),
            indexed_variable_value, current_value));
@@ -111,7 +125,6 @@ antlrcpp::Any qasm3_expression_generator::visitComparsionExpression(
  auto location = get_location(builder, file_name, compare);

  if (auto relational_op = compare->relationalOperator()) {
   
    visitChildren(compare->expression(0));
    auto lhs = current_value;
    visitChildren(compare->expression(1));
@@ -518,7 +531,7 @@ antlrcpp::Any qasm3_expression_generator::visitExpressionTerminator(
    qasm3Parser::ExpressionTerminatorContext* ctx) {
  auto location = get_location(builder, file_name, ctx);

  std::cout << "Analyze Expression Terminator: " << ctx->getText() << "\n";
  // std::cout << "Analyze Expression Terminator: " << ctx->getText() << "\n";

  if (ctx->Constant()) {
    auto const_str = ctx->Constant()->getText();
@@ -652,7 +665,50 @@ antlrcpp::Any qasm3_expression_generator::visitExpressionTerminator(
    printErrorMessage(
        "We only support bool(int|uint|uint[i]) cast operations.");

  } else {
  } else if (auto sub_call = ctx->subroutineCall()) {
    // std::cout << "ARE WE HERE: " << ctx->subroutineCall()->getText() << "\n";
    // std::cout << ctx->subroutineCall()->Identifier()->getText() << ", "
    //           << ctx->subroutineCall()->expressionList(0)->getText() << "\n";

    auto func =
        symbol_table.get_seen_function(sub_call->Identifier()->getText());

    std::vector<mlir::Value> operands;

    auto qubit_expr_list_idx = 0;
    auto expression_list = sub_call->expressionList();
    if (expression_list.size() > 1) {
      // we have parameters
      qubit_expr_list_idx = 1;

      for (auto expression : expression_list[0]->expression()) {
        std::cout << "Subcall expr: " << expression->getText() << "\n";
        // add parameter values:
        // FIXME THIS SHOULD MATCH TYPES for FUNCTION
        auto value = std::stod(expression->getText());
        auto float_attr = mlir::FloatAttr::get(builder.getF64Type(), value);
        mlir::Value val =
            builder.create<mlir::ConstantOp>(location, float_attr);
        operands.push_back(val);
      }
    }

    for (auto expression : expression_list[qubit_expr_list_idx]->expression()) {
      qasm3_expression_generator qubit_exp_generator(
          builder, symbol_table, file_name,
          get_custom_opaque_type("Qubit", builder.getContext()));
      qubit_exp_generator.visit(expression);

      operands.push_back(qubit_exp_generator.current_value);
    }
    auto call_op = builder.create<mlir::CallOp>(location, func,
                                                llvm::makeArrayRef(operands));
    update_current_value(call_op.getResult(0));

    return 0;
  }

  else {
    printErrorMessage("Cannot handle this expression terminator yet: " +
                      ctx->getText());
  }
Loading