Commit efbe54e8 authored by Mccaskey, Alex's avatar Mccaskey, Alex
Browse files

adding quantum random number generator example, plus bug fixes / enhancements to get it running

parent abef62b6
Loading
Loading
Loading
Loading
Loading
+43 −0
Original line number Diff line number Diff line
OPENQASM 3;

// Global constant, maximum bit size 
// for the random integer
const max_bits = 4;

// Generate a superposition and 
// measure to return a 50/50 random bit
def random_bit() qubit:a -> bit {
    h a;
    return measure a;
}

// Generate a random integer of max_bits bit width
// This will generate a random 0 or 1 
// based on a single provided qubit put 
// in a superposition
def generate_random_int() qubit:q -> int {
    // Create [0,0,0,...0] of size max_bits
    bit b[max_bits];

    // Set every bit as a random 0 or 1
    for i in [0:max_bits] {
        b[i] = random_bit() q;
        // reset qubit state for 
        // next iteration
        reset q;
    }
    // Print the binary string
    print("random binary: ", b);
    int n = int[32](b);
    return n;
}

// Allocate a single qubit
qubit a;

// Generate the random number 
// using the allocated qubit
int n = generate_random_int() a;

// print the random number
print("Random int (lsb): ", n);
 No newline at end of file
+13 −0
Original line number Diff line number Diff line
@@ -44,6 +44,8 @@ for i in [0:4] {
 }
}
QCOR_EXPECT_TRUE(loop_count == 12);


)#";
  auto mlir = qcor::mlir_compile("qasm3", for_stmt, "for_stmt",
                                 qcor::OutputType::MLIR, false);
@@ -63,6 +65,17 @@ QCOR_EXPECT_TRUE(i == 10);
                                 qcor::OutputType::MLIR, false);
  std::cout << mlir2 << "\n";
  EXPECT_FALSE(qcor::execute("qasm3", while_stmt, "while_stmt"));

    const std::string decrement = R"#(OPENQASM 3;
include "qelib1.inc";
for j in [10:-1:0] {
  print(j);
}
)#";
  auto mlir3 = qcor::mlir_compile("qasm3", decrement, "decrement",
                                 qcor::OutputType::MLIR, false);
  std::cout << mlir3 << "\n";
  EXPECT_FALSE(qcor::execute("qasm3", decrement, "decrement"));
}

int main(int argc, char **argv) {
+25 −14
Original line number Diff line number Diff line
@@ -214,7 +214,8 @@ antlrcpp::Any qasm3_visitor::visitNoDesignatorDeclaration(
      // Save the allocation, the store op
      symbol_table.add_symbol(variable, allocation);
    }
  } else if (context->noDesignatorType()->getText().find("int") != std::string::npos) {
  } else if (context->noDesignatorType()->getText().find("int") !=
             std::string::npos) {
    // THis can now be either an identifierList or an equalsAssignementList
    mlir::Attribute init_attr;
    mlir::Type value_type;
@@ -353,8 +354,7 @@ antlrcpp::Any qasm3_visitor::visitNoDesignatorDeclaration(
      // Save the allocation, the store op
      symbol_table.add_symbol(variable, allocation);
    }
  }
  else {
  } else {
    printErrorMessage("We do not yet support this no designator type: " +
                          context->noDesignatorType()->getText(),
                      context);
@@ -523,10 +523,20 @@ antlrcpp::Any qasm3_visitor::visitClassicalAssignment(
  // bit = subroutine_call(params) qbits...
  if (auto call_op = rhs.getDefiningOp<mlir::CallOp>()) {
    int bit_idx = 0;
    bool we_have_lhs_idx = false;
    mlir::Value v;
    if (auto index_list = context->indexIdentifier(0)->expressionList()) {
      // Need to extract element from bit array to set it
      auto idx_str = index_list->expression(0)->getText();
      bit_idx = std::stoi(idx_str);
      // auto idx_str = index_list->expression(0)->getText();
      // bit_idx = std::stoi(idx_str);
      we_have_lhs_idx = true;
      qasm3_expression_generator equals_exp_generator(builder, symbol_table,
                                                      file_name);
      equals_exp_generator.visit(index_list->expression(0));
      v = equals_exp_generator.current_value;
    } else {
      v = get_or_create_constant_index_value(0, location, 64, symbol_table,
                                             builder);
    }

    // Scenarios:
@@ -556,15 +566,14 @@ antlrcpp::Any qasm3_visitor::visitClassicalAssignment(
              llvm::makeArrayRef(std::vector<mlir::Value>{pos}));
        }
      } else {
        if (lhs_shape != 1) {
        if (lhs_shape != 1 && !we_have_lhs_idx) {
          printErrorMessage("rhs and lhs memref shapes do not match.", context,
                            {lhs, rhs});
        }
        mlir::Value pos = get_or_create_constant_integer_value(
            0, location, builder.getIntegerType(64), symbol_table, builder);

        builder.create<mlir::StoreOp>(
            location, rhs, lhs,
            llvm::makeArrayRef(std::vector<mlir::Value>{pos}));
            llvm::makeArrayRef(std::vector<mlir::Value>{v}));
      }
    } else {
      builder.create<mlir::StoreOp>(location, rhs, lhs);
@@ -682,9 +691,11 @@ antlrcpp::Any qasm3_visitor::visitClassicalAssignment(
  } else if (assignment_op == "^=") {
    current_value =
        builder.create<mlir::XOrOp>(location, load_result, load_result_rhs);
    // llvm::ArrayRef<mlir::Value> zero_index2(get_or_create_constant_index_value(
        // 0, location, 64, symbol_table, builder));
    builder.create<mlir::StoreOp>(location, current_value, lhs);//, zero_index2);
    // llvm::ArrayRef<mlir::Value>
    // zero_index2(get_or_create_constant_index_value( 0, location, 64,
    // symbol_table, builder));
    builder.create<mlir::StoreOp>(location, current_value,
                                  lhs);  //, zero_index2);
  } else if (assignment_op == "=") {
    // FIXME This assumes we have a memref<1x??> = memref<1x??>
    // what if we have multiple elements in the memref???
+30 −52
Original line number Diff line number Diff line
@@ -170,42 +170,32 @@ antlrcpp::Any qasm3_visitor::visitLoopStatement(
          c_value = get_or_create_constant_integer_value(c, location, int_type,
                                                         symbol_table, builder);

      if (symbol_table.has_symbol(range->expression(0)->getText())) {
        a_value = symbol_table.get_symbol(range->expression(0)->getText());
      qasm3_expression_generator exp_generator(builder, symbol_table,
                                               file_name);
      exp_generator.visit(range->expression(0));
      a_value = exp_generator.current_value;
      if (a_value.getType().isa<mlir::MemRefType>()) {
        a_value = builder.create<mlir::LoadOp>(location, a_value);
        if (a_value.getType() != int_type) {
          printErrorMessage("For loop a, b, and c types are not equal.",
                            context, {a_value, c_value});
        }
      } else {
        // infer the constant from the symbol table
        a = symbol_table.evaluate_constant_integer_expression(
            range->expression(0)->getText());
        a_value = get_or_create_constant_integer_value(a, location, int_type,
                                                       symbol_table, builder);
      }

      if (n_expr == 3) {
        if (symbol_table.has_symbol(range->expression(2)->getText())) {
          b_value = symbol_table.get_symbol(range->expression(2)->getText());
        qasm3_expression_generator exp_generator(builder, symbol_table,
                                                 file_name);
        exp_generator.visit(range->expression(2));
        b_value = exp_generator.current_value;
        if (b_value.getType().isa<mlir::MemRefType>()) {
          b_value = builder.create<mlir::LoadOp>(location, b_value);
          if (b_value.getType() != int_type) {
            printErrorMessage("For loop a, b, and c types are not equal.",
                              context, {a_value, b_value});
          }
        } else {
          b = symbol_table.evaluate_constant_integer_expression(
              range->expression(2)->getText());
          b_value = get_or_create_constant_integer_value(
              b, location, a_value.getType(), symbol_table, builder);
        }

        if (symbol_table.has_symbol(range->expression(1)->getText())) {
          c_value = symbol_table.get_symbol(range->expression(1)->getText());
          c_value = builder.create<mlir::LoadOp>(location, c_value);
          if (c_value.getType() != int_type) {
            printErrorMessage("For loop a, b, and c types are not equal.",
                              context, {a_value, c_value});
          }
          printErrorMessage("You must provide loop step as a constant value.",
                            context);
          // c_value = symbol_table.get_symbol(range->expression(1)->getText());
          // c_value = builder.create<mlir::LoadOp>(location, c_value);
          // if (c_value.getType() != int_type) {
          //   printErrorMessage("For loop a, b, and c types are not equal.",
          //                     context, {a_value, c_value});
          // }
        } else {
          c = symbol_table.evaluate_constant_integer_expression(
              range->expression(1)->getText());
@@ -214,14 +204,6 @@ antlrcpp::Any qasm3_visitor::visitLoopStatement(
        }

      } else {
        // if (symbol_table.has_symbol(range->expression(1)->getText())) {
        //   b_value = symbol_table.get_symbol(range->expression(1)->getText());
        //   b_value = builder.create<mlir::LoadOp>(location, b_value);
        //   if (b_value.getType() != int_type) {
        //     printErrorMessage("For loop a, b, and c types are not equal.",
        //                       context, {a_value, b_value});
        //   }
        // } else {
        qasm3_expression_generator exp_generator(builder, symbol_table,
                                                 file_name);
        exp_generator.visit(range->expression(1));
@@ -229,12 +211,6 @@ antlrcpp::Any qasm3_visitor::visitLoopStatement(
        if (b_value.getType().isa<mlir::MemRefType>()) {
          b_value = builder.create<mlir::LoadOp>(location, b_value);
        }

          // b = symbol_table.evaluate_constant_integer_expression(
          //     range->expression(1)->getText());
          // b_value = get_or_create_constant_integer_value(
          //     b, location, a_value.getType(), symbol_table, builder);
        // }
      }

      // Create a new scope for the for loop
@@ -278,7 +254,7 @@ antlrcpp::Any qasm3_visitor::visitLoopStatement(

      auto load = builder.create<mlir::LoadOp>(location, loop_var_memref);
      auto cmp = builder.create<mlir::CmpIOp>(
          location, mlir::CmpIPredicate::slt, load, b_value);
          location, c > 0 ? mlir::CmpIPredicate::slt : mlir::CmpIPredicate::sge, load, b_value);
      builder.create<mlir::CondBranchOp>(location, cmp, bodyBlock, exitBlock);

      builder.setInsertionPointToStart(bodyBlock);
@@ -299,8 +275,10 @@ antlrcpp::Any qasm3_visitor::visitLoopStatement(

      builder.setInsertionPointToStart(incBlock);
      auto load_inc = builder.create<mlir::LoadOp>(location, loop_var_memref);

        auto add = builder.create<mlir::AddIOp>(location, load_inc, c_value);


      builder.create<mlir::StoreOp>(location, add, loop_var_memref);

      builder.create<mlir::BranchOp>(location, headerBlock);
+9 −8
Original line number Diff line number Diff line
@@ -201,21 +201,22 @@ antlrcpp::Any qasm3_visitor::visitQuantumMeasurementAssignment(
          llvm::makeArrayRef(std::vector<mlir::Value>{}));

      // Get the bit or bit[]
      int bit_idx = 0;
      mlir::Value v;
      if (auto index_list =
              indexIdentifierList->indexIdentifier(0)->expressionList()) {
        // Need to extract element from bit array to set it
        auto idx_str = index_list->expression(0)->getText();
        bit_idx = std::stoi(idx_str);
        qasm3_expression_generator equals_exp_generator(builder, symbol_table,
                                                        file_name);
        equals_exp_generator.visit(index_list->expression(0));
        v = equals_exp_generator.current_value;
      } else {
        v = get_or_create_constant_index_value(
          0, location, 64, symbol_table, builder);
      }

      // Store the mz result into the bit_value
      mlir::Value pos = get_or_create_constant_index_value(
          bit_idx, location, 64, symbol_table, builder);

      builder.create<mlir::StoreOp>(
          location, instop.bit(), bit_value,
          llvm::makeArrayRef(std::vector<mlir::Value>{pos}));
          llvm::makeArrayRef(std::vector<mlir::Value>{v}));
    } else {
      // This is the case where we are measuring an entire qubit array
      // to a bit array
Loading