Commit 1ff3c72f authored by Mccaskey, Alex's avatar Mccaskey, Alex
Browse files

added loops back, if stmts, complex conditionals.

parent dc1b1002
Loading
Loading
Loading
Loading
+3 −0
Original line number Diff line number Diff line
@@ -10,6 +10,9 @@ file(GLOB SRC *.cpp antlr/generated/*.cpp utils/*.cpp
   visitor_handlers/quantum_types_handler.cpp
   visitor_handlers/quantum_instruction_handler.cpp
   visitor_handlers/classical_types_handler.cpp
   visitor_handlers/measurement_handler.cpp
   visitor_handlers/loop_stmt_handler.cpp
   visitor_handlers/conditional_handler.cpp
   )

add_library(${LIBRARY_NAME} SHARED ${SRC})
+12 −56
Original line number Diff line number Diff line
@@ -47,10 +47,10 @@ class qasm3_visitor : public qasm3::qasm3BaseVisitor {
      qasm3Parser::KernelCallContext* context) override;

  // see visitor_handlers/measurement_handler.cpp
  // antlrcpp::Any visitQuantumMeasurement(
  //     qasm3Parser::QuantumMeasurementContext* context) override;
  // antlrcpp::Any visitQuantumMeasurementAssignment(
  //     qasm3Parser::QuantumMeasurementAssignmentContext* context) override;
  antlrcpp::Any visitQuantumMeasurement(
      qasm3Parser::QuantumMeasurementContext* context) override;
  antlrcpp::Any visitQuantumMeasurementAssignment(
      qasm3Parser::QuantumMeasurementAssignmentContext* context) override;

  // // see visitor_handlers/subroutine_handler.cpp
  // antlrcpp::Any visitSubroutineDefinition(
@@ -58,15 +58,15 @@ class qasm3_visitor : public qasm3::qasm3BaseVisitor {
  // antlrcpp::Any visitReturnStatement(
  //     qasm3Parser::ReturnStatementContext* context) override;

  // // see visitor_handlers/conditional_handler.cpp
  // antlrcpp::Any visitBranchingStatement(
  //     qasm3Parser::BranchingStatementContext* context) override;
  // see visitor_handlers/conditional_handler.cpp
  antlrcpp::Any visitBranchingStatement(
      qasm3Parser::BranchingStatementContext* context) override;

  // // see visitor_handlers/for_stmt_handler.cpp
  // antlrcpp::Any visitLoopStatement(
  //     qasm3Parser::LoopStatementContext* context) override;
  // antlrcpp::Any visitControlDirective(
  //     qasm3Parser::ControlDirectiveContext* context) override;
  // see visitor_handlers/for_stmt_handler.cpp
  antlrcpp::Any visitLoopStatement(
      qasm3Parser::LoopStatementContext* context) override;
  antlrcpp::Any visitControlDirective(
      qasm3Parser::ControlDirectiveContext* context) override;

  // see visitor_handlers/classical_types_handler.cpp
  antlrcpp::Any visitConstantDeclaration(
@@ -125,50 +125,6 @@ class qasm3_visitor : public qasm3::qasm3BaseVisitor {
    return;
  }

  // mlir::Value get_or_extract_qubit(const std::string& qreg_name,
  //                                  const std::size_t idx,
  //                                  mlir::Location location) {
  //   auto key = qreg_name + std::to_string(idx);
  //   if (symbol_table.has_symbol(key)) {
  //     return symbol_table.get_symbol(key);  // global_symbol_table[key];
  //   } else {
  //     auto qubits = symbol_table.get_symbol(qreg_name);
  //     // .getDefiningOp<mlir::quantum::QallocOp>()
  //     // .qubits();
  //     mlir::Value pos = get_or_create_constant_integer_value(idx, location);

  //     // auto pos = create_constant_integer_value(idx, location);
  //     auto value = builder.create<mlir::quantum::ExtractQubitOp>(
  //         location, qubit_type, qubits, pos);
  //     symbol_table.add_symbol(key, value);
  //     return value;
  //   }
  // }

  // mlir::Value get_or_create_constant_integer_value(const std::size_t idx,
  //                                                  mlir::Location location,
  //                                                  int width = 64) {
  //   if (symbol_table.has_constant_integer(idx, width)) {
  //     return symbol_table.get_constant_integer(idx, width);
  //   } else {
  //     auto integer_attr =
  //         mlir::IntegerAttr::get(builder.getIntegerType(width), idx);

  //     auto ret = builder.create<mlir::ConstantOp>(location, integer_attr);
  //     symbol_table.add_constant_integer(idx, ret, width);
  //     return ret;
  //   }
  // }

  // mlir::Value get_or_create_constant_index_value(const std::size_t idx,
  //                                                mlir::Location location,
  //                                                int width = 64) {
  //   auto constant_int =
  //       get_or_create_constant_integer_value(idx, location, width);
  //   return builder.create<mlir::IndexCastOp>(location, constant_int,
  //                                            builder.getIndexType());
  // }

  // This function serves as a utility for creating a MemRef and
  // corresponding AllocOp of a given 1d shape. It will also store
  // initial values to all elements of the 1d array.
+196 −1
Original line number Diff line number Diff line
@@ -75,9 +75,204 @@ print(ff);
  qcor::execute("qasm3", src, "test");
}

TEST(qasm3VisitorTester, checkMeasurements) {
  const std::string measure_test = R"#(OPENQASM 3;
include "qelib1.inc";
qubit q;
qubit qq[2];

bit r, s;
r = measure q;
s = measure qq[0];

bit rr[2];
rr[0] = measure qq[0];

bit xx[4];
qubit qqq[4];
xx = measure qqq;

// bit y, yy[2];
// measure q -> y;
// measure qq -> yy;
// measure qq[0] -> y;
)#";
  auto mlir = qcor::mlir_compile("qasm3", measure_test, "measure_test",
                                 qcor::OutputType::MLIR, false);
  std::cout << "MLIR:\n" << mlir << "\n";
  // qcor::execute("qasm3", measure_test, "test");
}

TEST(qasm3VisitorTester, checkQuantumInsts) {
  const std::string qinst_test = R"#(OPENQASM 3;
include "qelib1.inc";
qubit q;
h q;
ry(2.2) q;

qubit qq[2];
x qq[0];
CX qq[0], qq[1];
U(0.1,0.2,0.3) qq[1];
cx q, qq[1];

)#";
  auto mlir = qcor::mlir_compile("qasm3", qinst_test, "qinst_test",
                                 qcor::OutputType::MLIR, false);
  std::cout << "MLIR:\n" << mlir << "\n";
}

TEST(qasm3VisitorTester, checkLoopStmt) {
  const std::string for_stmt = R"#(OPENQASM 3;
include "qelib1.inc";

for i in {11,22,33} {
    print(i);
}

for i in [0:10] {
    print(i);
    
}
for j in [0:2:4] {
    print("steps:", j);
}


for j in [0:4] {
    print("j in 0:4", j);
}

for i in [0:4] {
 for j in {1,2,3} {
     print(i,j);
 }
}
)#";
  auto mlir = qcor::mlir_compile("qasm3", for_stmt, "for_stmt",
                                 qcor::OutputType::MLIR, false);
  std::cout << "for_stmt MLIR:\n" << mlir << "\n";
  qcor::execute("qasm3", for_stmt, "for_stmt");
}

TEST(qasm3VisitorTester, checkUintIndexing) {
  const std::string uint_index = R"#(OPENQASM 3;
include "qelib1.inc";

uint[4] b_in = 15;

bool b1 = bool(b_in[0]);
bool b2 = bool(b_in[1]);
bool b3 = bool(b_in[2]);
bool b4 = bool(b_in[3]);

print(b1,b2,b3,b4);
)#";
  auto mlir = qcor::mlir_compile("qasm3", uint_index, "uint_index",
                                 qcor::OutputType::MLIR, false);
  std::cout << mlir << "\n";
  qcor::execute("qasm3", uint_index, "uint_index");
}

TEST(qasm3VisitorTester, checkIfStmt) {
  const std::string if_stmt = R"#(OPENQASM 3;
include "qelib1.inc";
qubit q, s;//, qq[2];
const layers = 2;
bit cc[2];
qubit qq[2];

bit c;
c = measure q;
cc[0] = measure qq[0];

if (c == 1) {
    z s;
} else {
  print("c was a 0");
}

if (layers == 2) {
    print("should be here, layers is 2");
    z s;
} 


cc[1] = measure qq[1];
if ( cc[1] == 1) {
  ry(2.2) s;
}


)#";
  auto mlir = qcor::mlir_compile("qasm3", if_stmt, "if_stmt",
                                 qcor::OutputType::MLIR, false);
  std::cout << mlir << "\n";
}

TEST(qasm3VisitorTester, checkSecondIfStmt) {
  const std::string if_stmt = R"#(OPENQASM 3;
include "qelib1.inc";
qubit q, s, qqq[2];
bit c;

if (!c) {
 print("you should see me");
}
x q;
c = measure q;
if (c == 1) {
  print("hi");
  ry(2.2) s;
} 

c = measure qqq[0];
print("hi world");

)#";
  auto mlir = qcor::mlir_compile("qasm3", if_stmt, "if_stmt",
                                 qcor::OutputType::MLIR, false);
  std::cout << mlir << "\n";
  qcor::execute("qasm3", if_stmt, "if_stmt");
}

TEST(qasm3VisitorTester, checkIfStmt3) {
  const std::string complex_if = R"#(OPENQASM 3;
include "qelib1.inc";
qubit q;
const n = 10;
int[32] i = 3;

bit temp;
if(temp==0 && i==3) {
  print("we are here"); 
  temp = measure q; 
}

)#";
  auto mlir = qcor::mlir_compile("qasm3", complex_if, "complex_if",
                                 qcor::OutputType::MLIR, false);
  std::cout << mlir << "\n";
  qcor::execute("qasm3", complex_if, "complex_if");

}

TEST(qasm3VisitorTester, checkWhile) {
  const std::string while_stmt = R"#(OPENQASM 3;
include "qelib1.inc";
int[32] i = 0;
while (i < 10) {
  print(i);
  i += 1;
}
)#";
  auto mlir = qcor::mlir_compile("qasm3", while_stmt, "while_stmt",
                                 qcor::OutputType::MLIR, false);
  std::cout << mlir << "\n";
  qcor::execute("qasm3", while_stmt, "while_stmt");
}
int main(int argc, char **argv) {
  ::testing::InitGoogleTest(&argc, argv);
  auto ret = RUN_ALL_TESTS();
  return ret;
}
+39 −168
Original line number Diff line number Diff line
program
  header
    version
      OPENQASM
      3
      ;
    include
      include
      "qelib1.inc"
      ;
  statement
    classicalDeclarationStatement
      classicalDeclaration
        singleDesignatorDeclaration
          singleDesignatorType
            int
          designator
            [
            expression
              expressionTerminator
                10
            ]
          identifierList
            x
            ,
            y
      ;
  statement
    classicalDeclarationStatement
      classicalDeclaration
        singleDesignatorDeclaration
          singleDesignatorType
            int
          designator
            [
            expression
              expressionTerminator
                10
            ]
          equalsAssignmentList
            xx
            equalsExpression
              =
      branchingStatement
        if
        (
        booleanExpression
          booleanExpression
            comparsionExpression
              expression
                expressionTerminator
                  2
            ,
            yy
            equalsExpression
              =
              expression
                  expressionTerminator
                  1
      ;
  globalStatement
    quantumDeclarationStatement
      quantumDeclaration
        quantumType
          qubit
        indexIdentifierList
          indexIdentifier
            q1
                    spec
                  [
            expressionList
                  expression
                    expressionTerminator
                  6
                      i
                  ]
          ,
          indexIdentifier
            q2
      ;
  statement
    classicalDeclarationStatement
      classicalDeclaration
        bitDeclaration
          bitType
            bit
          indexIdentifierList
            indexIdentifier
              k
      ;
  statement
    classicalDeclarationStatement
      classicalDeclaration
        bitDeclaration
          bitType
            bit
          indexIdentifierList
            indexIdentifier
              kk
              [
              expressionList
              relationalOperator
                ==
              expression
                expressionTerminator
                    10
              ]
      ;
  statement
    classicalDeclarationStatement
      classicalDeclaration
        bitDeclaration
          bitType
            bit
          indexEqualsAssignmentList
            indexIdentifier
              b1
              [
              expressionList
                  0
          logicalOperator
            &&
          comparsionExpression
            expression
              expressionTerminator
                    4
              ]
            equalsExpression
              =
              expression
                expressionTerminator
                  "0100"
            ,
            indexIdentifier
              b2
            equalsExpression
              =
              expression
                expressionTerminator
                  "1"
      ;
  statement
    classicalDeclarationStatement
      classicalDeclaration
        noDesignatorDeclaration
          noDesignatorType
            bool
          equalsAssignmentList
            m
            equalsExpression
              =
                  spec
                [
                expression
                  xOrExpression
                    bitAndExpression
                      bitShiftExpression
                        additiveExpression
                          additiveExpression
                            multiplicativeExpression
                              expressionTerminator
                  True
            ,
                                n
            equalsExpression
              =
              expression
                          +
                          multiplicativeExpression
                            expressionTerminator
                  builtInCall
                    castOperator
                      classicalType
                        noDesignatorType
                          bool
                    (
                    expressionList
                              i
                ]
            relationalOperator
              ==
            expression
              expressionTerminator
                          b2
                1
        )
 No newline at end of file
      ;
  statement
    classicalDeclarationStatement
      constantDeclaration
        const
        equalsAssignmentList
          c
          equalsExpression
            =
            expression
              expressionTerminator
                5.5e3
          ,
          d
          equalsExpression
            =
            expression
              expressionTerminator
                5
      ;
 No newline at end of file
+184 −15
Original line number Diff line number Diff line
@@ -36,6 +36,7 @@ antlrcpp::Any qasm3_expression_generator::visitTerminal(
    // We have hit a closing on an index
    // std::cout << "TERMNODE:\n";
    indexed_variable_value = current_value;
    internal_value_type = builder.getIndexType();
  } else if (node->getSymbol()->getText() == "]") {
    if (casting_indexed_integer_to_bool) {
      // We have an indexed integer in indexed_variable_value
@@ -49,34 +50,53 @@ antlrcpp::Any qasm3_expression_generator::visitTerminal(
      // uint[4] b_in = 15; // b = 1111
      // bool(b_in[1]);

      std::cout << "FIRST:\n";
      indexed_variable_value.dump();
      
      // auto number_value = builder.create<mlir::LoadOp>(location,
      // indexed_variable_value, get_or_create_constant_index_value(0,
      // location)); number_value.dump(); auto idx_minus_1 =
      // builder.create<mlir::SubIOp>(location, current_value,
      // get_or_create_constant_integer_value(1, location));
      auto bw = indexed_variable_value.getType().getIntOrFloatBitWidth();
      auto casted_idx = builder.create<mlir::IndexCastOp>(
          location, current_value, indexed_variable_value.getType());
      auto casted_idx =
          builder.create<mlir::IndexCastOp>(location, current_value,
                                            indexed_variable_value.getType()
                                                .cast<mlir::MemRefType>()
                                                .getElementType());
      auto load_value = builder.create<mlir::LoadOp>(
          location, indexed_variable_value,
          get_or_create_constant_index_value(0, location, 64, symbol_table,
                                             builder));
      auto shift = builder.create<mlir::UnsignedShiftRightOp>(
          location, indexed_variable_value.getType(), indexed_variable_value,
          casted_idx);
      // shift.dump();
          location, load_value, casted_idx);
      // auto shift_load_value = builder.create<mlir::LoadOp>(
      //     location, shift,
      //     get_or_create_constant_index_value(0, location, 64, symbol_table,
      //                                        builder));
      auto old_int_type = internal_value_type;
      internal_value_type = indexed_variable_value.getType();
      auto and_value = builder.create<mlir::AndOp>(
          location, shift,
          get_or_create_constant_integer_value(
              1, location, builder.getIntegerType(bw), symbol_table, builder));
          get_or_create_constant_integer_value(1, location,
                                               indexed_variable_value.getType()
                                                   .cast<mlir::MemRefType>()
                                                   .getElementType(),
                                               symbol_table, builder));
      internal_value_type = old_int_type;
      update_current_value(and_value.result());

      casting_indexed_integer_to_bool = false;
    } else {
      if (internal_value_type.isa<mlir::OpaqueType>() &&
          internal_value_type.cast<mlir::OpaqueType>().getTypeData().str() ==
              "Qubit") {
        update_current_value(builder.create<mlir::quantum::ExtractQubitOp>(
            location, get_custom_opaque_type("Qubit", builder.getContext()),
            indexed_variable_value, current_value));
      } else {
        // We are loading from a variable
        llvm::ArrayRef<mlir::Value> idx(current_value);
      update_current_value(
          builder.create<mlir::LoadOp>(location, indexed_variable_value, idx));
        update_current_value(builder.create<mlir::LoadOp>(
            location, indexed_variable_value, idx));
      }
    }
  }
  return 0;
@@ -86,7 +106,129 @@ antlrcpp::Any qasm3_expression_generator::visitExpression(
    qasm3Parser::ExpressionContext* ctx) {
  return visitChildren(ctx);
}
antlrcpp::Any qasm3_expression_generator::visitComparsionExpression(
    qasm3Parser::ComparsionExpressionContext* compare) {
  auto location = get_location(builder, file_name, compare);

  if (auto relational_op = compare->relationalOperator()) {
   
    visitChildren(compare->expression(0));
    auto lhs = current_value;
    visitChildren(compare->expression(1));
    auto rhs = current_value;

    // if lhs is memref of rank 1 and size 1, this is a
    // variable and we need to load its value
    auto lhs_type = lhs.getType();
    auto rhs_type = rhs.getType();
    if (auto mem_value_type = lhs_type.dyn_cast_or_null<mlir::MemRefType>()) {
      if (mem_value_type.getElementType().isIntOrIndex() &&
          mem_value_type.getRank() == 1 && mem_value_type.getShape()[0] == 1) {
        // Load this memref value

        lhs = builder.create<mlir::LoadOp>(
            location, lhs,
            get_or_create_constant_index_value(0, location, 64, symbol_table,
                                               builder));
      }
    }

    if (auto mem_value_type = rhs_type.dyn_cast_or_null<mlir::MemRefType>()) {
      if (mem_value_type.getElementType().isIntOrIndex() &&
          mem_value_type.getRank() == 1 && mem_value_type.getShape()[0] == 1) {
        // Load this memref value

        rhs = builder.create<mlir::LoadOp>(
            location, rhs,
            get_or_create_constant_index_value(0, location, 64, symbol_table,
                                               builder));
      }
    }

    auto op = relational_op->getText();
    if (antlr_to_mlir_predicate.count(op)) {
      // if so, get the mlir enum representing it
      auto predicate = antlr_to_mlir_predicate[op];

      auto lhs_bw = lhs.getType().getIntOrFloatBitWidth();
      auto rhs_bw = rhs.getType().getIntOrFloatBitWidth();
      // We need the comparison to be on the same bit width
      if (lhs_bw < rhs_bw) {
        rhs = builder.create<mlir::IndexCastOp>(location, rhs,
                                                builder.getIntegerType(lhs_bw));
      } else if (lhs_bw > rhs_bw) {
        lhs = builder.create<mlir::IndexCastOp>(location, lhs,
                                                builder.getIntegerType(rhs_bw));
      }

      // create the binary op value
      update_current_value(
          builder.create<mlir::CmpIOp>(location, predicate, lhs, rhs));
      return 0;
    } else {
      printErrorMessage("Invalid relational operation: " + op);
    }

  } else {
    // This is just if(expr)
    // printErrorMessage("Alex please implement if(expr).");

    found_negation_unary_op = false;
    visitChildren(compare->expression(0));
    // now just compare current_value to 1
    mlir::Type current_value_type =
        current_value.getType().isa<mlir::MemRefType>()
            ? current_value.getType().cast<mlir::MemRefType>().getElementType()
            : current_value.getType();

    current_value = builder.create<mlir::LoadOp>(
        location, current_value,
        get_or_create_constant_index_value(0, location, 64, symbol_table,
                                           builder));

    mlir::CmpIPredicate p = mlir::CmpIPredicate::eq;
    if (found_negation_unary_op) {
      p = mlir::CmpIPredicate::ne;
    }

    current_value = builder.create<mlir::CmpIOp>(
        location, p, current_value,
        get_or_create_constant_integer_value(1, location, current_value_type,
                                             symbol_table, builder));
    return 0;
  }
  return visitChildren(compare);
}

antlrcpp::Any qasm3_expression_generator::visitBooleanExpression(
    qasm3Parser::BooleanExpressionContext* ctx) {
  auto location = get_location(builder, file_name, ctx);

  if (ctx->logicalOperator()) {
    auto bool_expr = ctx->booleanExpression();
    visitChildren(bool_expr);
    auto lhs = current_value;

    visit(ctx->comparsionExpression());
    auto rhs = current_value;

    if (ctx->logicalOperator()->getText() == "&&") {
      update_current_value(builder.create<mlir::AndOp>(location, lhs, rhs));
      return 0;
    }
  }
  return visitChildren(ctx);
}

antlrcpp::Any qasm3_expression_generator::visitUnaryExpression(
    qasm3Parser::UnaryExpressionContext* ctx) {
  if (auto unary_op = ctx->unaryOperator()) {
    if (unary_op->getText() == "!") {
      found_negation_unary_op = true;
    }
  }
  return visitChildren(ctx);
}
// antlrcpp::Any qasm3_expression_generator::visitIncrementor(
//     qasm3Parser::IncrementorContext* ctx) {
//   auto location = get_location(builder, file_name, ctx);
@@ -142,6 +284,20 @@ antlrcpp::Any qasm3_expression_generator::visitAdditiveExpression(
    visitChildren(ctx->multiplicativeExpression());
    auto rhs = current_value;

    if (lhs.getType().isa<mlir::MemRefType>()) {
      lhs = builder.create<mlir::LoadOp>(
          location, lhs,
          get_or_create_constant_index_value(0, location, 64, symbol_table,
                                             builder));
    }

    if (rhs.getType().isa<mlir::MemRefType>()) {
      rhs = builder.create<mlir::LoadOp>(
          location, rhs,
          get_or_create_constant_index_value(0, location, 64, symbol_table,
                                             builder));
    }

    if (bin_op == "+") {
      if (lhs.getType().isa<mlir::FloatType>() ||
          rhs.getType().isa<mlir::FloatType>()) {
@@ -448,10 +604,15 @@ antlrcpp::Any qasm3_expression_generator::visitExpressionTerminator(
      if (no_desig_type && no_desig_type->getText() == "bool") {
        // We can cast these things to bool...
        auto expr = builtin->expressionList()->expression(0);
        // std::cout << "EXPR: " << expr->getText() << "\n";
        if (expr->getText().find("[") != std::string::npos) {
          casting_indexed_integer_to_bool = true;
        }
        visitChildren(expr);
        auto value_type = current_value.getType();
        // std::cout << "DUMP THIS:\n";
        // value_type.dump();
        // current_value.dump();
        if (auto mem_value_type =
                value_type.dyn_cast_or_null<mlir::MemRefType>()) {
          if (mem_value_type.getElementType().isIntOrIndex() &&
@@ -477,11 +638,19 @@ antlrcpp::Any qasm3_expression_generator::visitExpressionTerminator(
            printErrorMessage("We can only cast integer types to bool. (" +
                              builtin->getText() + ").");
          }
        } else {
          // This is to catch things like bool(uint[i])
          current_value = builder.create<mlir::CmpIOp>(
              location, mlir::CmpIPredicate::eq, current_value,
              get_or_create_constant_integer_value(
                  1, location, current_value.getType(), symbol_table, builder));
          return 0;
        }
      }
    }

    printErrorMessage("We only support bool() cast operations.");
    printErrorMessage(
        "We only support bool(int|uint|uint[i]) cast operations.");

  } else {
    printErrorMessage("Cannot handle this expression terminator yet: " +
Loading