Commit 9f6a00ee authored by Mccaskey, Alex's avatar Mccaskey, Alex
Browse files

update for loop to use declared variables as loop indices, minor bug fixes,...


update for loop to use declared variables as loop indices, minor bug fixes, add tests for control directives and superposition demo, add deuteron ftqc example

Signed-off-by: Mccaskey, Alex's avatarAlex McCaskey <mccaskeyaj@ornl.gov>
parent e197af57
Loading
Loading
Loading
Loading
Loading
+44 −0
Original line number Diff line number Diff line

OPENQASM 3;
include "stdgates.inc";

const shots = 1024;

def deuteron(float[64]:theta) qubit[2]:q -> float[64] {
    bit first, second;
    float[64] num_parity_ones = 0.0;
    float[64] result;
    for i in [0:shots] {
        x q[0];
        ry(theta) q[1];
        cx q[1], q[0];

        h q;

        first = measure q[0];
        second = measure q[1];

        if (first != second) {
            num_parity_ones += 1.0;
        }

        reset q;
    }

    // Compute expectation value
    result = (shots - num_parity_ones) / shots - num_parity_ones / shots;
    return result;
}

float[64] theta, result, avg;
qubit qq[2];

int[32] n_trials = 10;
for i in [0:n_trials] {
  result = deuteron(theta) qq;
  avg += result;
  print("<X0X1> = ", result, avg);
}

avg /= n_trials;
print("Avg <X0X1> = ", avg);
 No newline at end of file
+11 −0
Original line number Diff line number Diff line
@@ -58,3 +58,14 @@ add_executable(qasm3CompilerTester_Arithmetic test_complex_arithmetic.cpp)
add_test(NAME qcor_qasm3_test_arithmetic COMMAND qasm3CompilerTester_Arithmetic)
target_include_directories(qasm3CompilerTester_Arithmetic PRIVATE . ../../ ${XACC_ROOT}/include/gtest)
target_link_libraries(qasm3CompilerTester_Arithmetic qcor-mlir-api gtest gtest_main)


add_executable(qasm3CompilerTester_ControlDirectives test_control_directives.cpp)
add_test(NAME qcor_qasm3_test_control_directives COMMAND qasm3CompilerTester_ControlDirectives)
target_include_directories(qasm3CompilerTester_ControlDirectives PRIVATE . ../../ ${XACC_ROOT}/include/gtest)
target_link_libraries(qasm3CompilerTester_ControlDirectives qcor-mlir-api gtest gtest_main)

add_executable(qasm3CompilerTester_Superposition test_superposition.cpp)
add_test(NAME qcor_qasm3_test_superposition COMMAND qasm3CompilerTester_Superposition)
target_include_directories(qasm3CompilerTester_Superposition PRIVATE . ../../ ${XACC_ROOT}/include/gtest)
target_link_libraries(qasm3CompilerTester_Superposition qcor-mlir-api gtest gtest_main)
+41 −0
Original line number Diff line number Diff line
#include "gtest/gtest.h"
#include "qcor_mlir_api.hpp"

TEST(qasm3VisitorTester, checkCtrlDirectives) {
  const std::string uint_index = R"#(OPENQASM 3;
include "qelib1.inc";

int[64] iterate_value = 0;
int[64] hit_continue_value = 0;
for i in [0:10] {
    iterate_value = i;
    if (i == 5) {
        print("breaking at 5");
        break;
    }
    if (i == 2) {
        hit_continue_value = i;
        print("continuing at 2");
        continue;
    }
    print("i = ", i);
}

QCOR_EXPECT_TRUE(iterate_value == 5);
QCOR_EXPECT_TRUE(hit_continue_value == 2);

print("made it out of the loop");

)#";
  auto mlir = qcor::mlir_compile("qasm3", uint_index, "uint_index",
                                 qcor::OutputType::MLIR, false);
  std::cout << mlir << "\n";
  EXPECT_FALSE(qcor::execute("qasm3", uint_index, "uint_index"));
}


int main(int argc, char **argv) {
  ::testing::InitGoogleTest(&argc, argv);
  auto ret = RUN_ALL_TESTS();
  return ret;
}
 No newline at end of file
+43 −0
Original line number Diff line number Diff line
#include "gtest/gtest.h"
#include "qcor_mlir_api.hpp"

TEST(qasm3VisitorTester, checkSuperposition) {
  const std::string uint_index = R"#(OPENQASM 3;
include "qelib1.inc";
qubit q;
bit c;

const shots = 1024;
int[32] ones = 0;
int[32] zeros = 0;

for i in [0:shots] {
  h q;
  c = measure q;
  if (c == 1) {
   ones += 1;
  } else {
   zeros += 1;
  }
  reset q;
}

print("N |1> measured = ", ones);
print("N |0> measured = ", zeros);

// give the randomness a bit of wriggle room
QCOR_EXPECT_TRUE(ones > 490);
QCOR_EXPECT_TRUE(zeros > 490);
)#";
  auto mlir = qcor::mlir_compile("qasm3", uint_index, "uint_index",
                                 qcor::OutputType::MLIR, false);
  std::cout << mlir << "\n";
  EXPECT_FALSE(qcor::execute("qasm3", uint_index, "uint_index"));
}


int main(int argc, char **argv) {
  ::testing::InitGoogleTest(&argc, argv);
  auto ret = RUN_ALL_TESTS();
  return ret;
}
 No newline at end of file
+39 −56
Original line number Diff line number Diff line
@@ -362,35 +362,11 @@ antlrcpp::Any qasm3_expression_generator::visitAdditiveExpression(
        // One of these at least is a float, need to have
        // both as float
        if (!lhs.getType().isa<mlir::FloatType>()) {
          if (auto op = lhs.getDefiningOp<mlir::ConstantOp>()) {
            auto value = op.getValue()
                             .cast<mlir::IntegerAttr>()
                             .getValue()
                             .getLimitedValue();
            lhs = builder.create<mlir::ConstantOp>(
                location, mlir::FloatAttr::get(rhs.getType(), (double)value));
          } else {
            printErrorMessage("Must cast lhs to float, but it is not constant.",
                              ctx, {lhs, rhs});
          }
          lhs = builder.create<mlir::SIToFPOp>(location, lhs, rhs.getType());

        } else if (!rhs.getType().isa<mlir::FloatType>()) {
          if (auto op = rhs.getDefiningOp<mlir::ConstantOp>()) {
            auto value = op.getValue()
                             .cast<mlir::IntegerAttr>()
                             .getValue()
                             .getLimitedValue();
            rhs = builder.create<mlir::ConstantOp>(
                location, mlir::FloatAttr::get(lhs.getType(), (double)value));
          } else {
            printErrorMessage("Must cast rhs to float, but it is not constant.",
                              ctx, {lhs, rhs});
          }
          rhs = builder.create<mlir::SIToFPOp>(location, rhs, lhs.getType());
        }
        // else {
        //   printErrorMessage(
        //       "Could not perform subtraction, incompatible types: " +
        //       ctx->getText());
        // }

        createOp<mlir::SubFOp>(location, lhs, rhs);
      } else if (lhs.getType().isa<mlir::IntegerType>() &&
@@ -488,13 +464,11 @@ antlrcpp::Any qasm3_expression_generator::visitXOrExpression(

      mlir::Value loop_var_memref = builder.create<mlir::AllocaOp>(
          location, mlir::MemRefType::get(shaperef, builder.getI64Type()));
      builder.create<mlir::StoreOp>(location, ret2,
                                    loop_var_memref);  
      builder.create<mlir::StoreOp>(location, ret2, loop_var_memref);

      mlir::Value product_memref = builder.create<mlir::AllocaOp>(
          location, mlir::MemRefType::get(shaperef, lhs_element_type));
      builder.create<mlir::StoreOp>(location, ret3,
                                    product_memref); 
      builder.create<mlir::StoreOp>(location, ret3, product_memref);

      auto integer_attr = mlir::IntegerAttr::get(builder.getI64Type(), 1);
      auto ret = builder.create<mlir::ConstantOp>(location, integer_attr);
@@ -600,7 +574,6 @@ antlrcpp::Any qasm3_expression_generator::visitMultiplicativeExpression(
        // One of these at least is a float, need to have
        // both as float
        if (!lhs.getType().isa<mlir::FloatType>()) {

          lhs = builder.create<mlir::SIToFPOp>(location, lhs, rhs.getType());

        } else if (!rhs.getType().isa<mlir::FloatType>()) {
@@ -622,31 +595,41 @@ antlrcpp::Any qasm3_expression_generator::visitMultiplicativeExpression(
        // One of these at least is a float, need to have
        // both as float
        if (!lhs.getType().isa<mlir::FloatType>()) {
          if (auto op = lhs.getDefiningOp<mlir::ConstantOp>()) {
            auto value = op.getValue()
                             .cast<mlir::IntegerAttr>()
                             .getValue()
                             .getLimitedValue();
            lhs = builder.create<mlir::ConstantOp>(
                location, mlir::FloatAttr::get(rhs.getType(), (double)value));
          } else {
            printErrorMessage(
                "Must cast lhs to float, but it is not constant.");
          }
          lhs = builder.create<mlir::SIToFPOp>(location, lhs, rhs.getType());

        } else if (!rhs.getType().isa<mlir::FloatType>()) {
          if (auto op = rhs.getDefiningOp<mlir::ConstantOp>()) {
            auto value = op.getValue()
                             .cast<mlir::IntegerAttr>()
                             .getValue()
                             .getLimitedValue();
            rhs = builder.create<mlir::ConstantOp>(
                location, mlir::FloatAttr::get(lhs.getType(), (double)value));
          } else {
            printErrorMessage(
                "Must cast rhs to float, but it is not constant.");
          }
          rhs = builder.create<mlir::SIToFPOp>(location, rhs, lhs.getType());
        }

        // if (!lhs.getType().isa<mlir::FloatType>()) {
        //   if (auto op = lhs.getDefiningOp<mlir::ConstantOp>()) {
        //     auto value = op.getValue()
        //                      .cast<mlir::IntegerAttr>()
        //                      .getValue()
        //                      .getLimitedValue();
        //     lhs = builder.create<mlir::ConstantOp>(
        //         location, mlir::FloatAttr::get(rhs.getType(),
        //         (double)value));
        //   } else {
        //     printErrorMessage(
        //         "Must cast lhs to float, but it is not constant.");
        //   }
        // } else if (!rhs.getType().isa<mlir::FloatType>()) {
        //   if (auto op = rhs.getDefiningOp<mlir::ConstantOp>()) {
        //     auto value = op.getValue()
        //                      .cast<mlir::IntegerAttr>()
        //                      .getValue()
        //                      .getLimitedValue();
        //     rhs = builder.create<mlir::ConstantOp>(
        //         location, mlir::FloatAttr::get(lhs.getType(),
        //         (double)value));
        //   } else {
        //     printErrorMessage(
        //         "Must cast rhs to float, but it is not constant.", ctx, {lhs,
        //         rhs});
        //   }
        // }

        createOp<mlir::DivFOp>(location, lhs, rhs);
      } else if (lhs.getType().isa<mlir::IntegerType>() &&
                 rhs.getType().isa<mlir::IntegerType>()) {
Loading