Commit b6792765 authored by Mccaskey, Alex's avatar Mccaskey, Alex
Browse files

bug fixes and implementation to support qpe.qasm example using ctrl and pow...


bug fixes and implementation to support qpe.qasm example using ctrl and pow gate modifiers, plus a iqft subroutine.

Signed-off-by: Mccaskey, Alex's avatarAlex McCaskey <mccaskeyaj@ornl.gov>
parent 6323223a
Loading
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -73,7 +73,7 @@ def StartPowURegion : QuantumOp<"start_pow_u_region", []> {
}

def EndPowURegion : QuantumOp<"end_pow_u_region", []> {
  let arguments = (ins AnyI64Attr:$pow);
  let arguments = (ins AnyI64:$pow);
  let results = (outs);
}

+64 −0
Original line number Diff line number Diff line
OPENQASM 3;

const n_counting = 3;

// For this example, the oracle is the T gate 
// on the provided qubit
gate oracle b {
    t b;
}

// Inverse QFT subroutine on n_counting qubits
def iqft qubit[n_counting]:qq {
    for i in [0:n_counting/2] {
        swap qq[i], qq[n_counting-i-1];
    }
    for i in [0:n_counting-1] {
        h qq[i];
        int j = i + 1;
        int y = i;
        while (y >= 0) {
            double theta = -pi / (2^(j-y));
            cphase(theta) qq[j], qq[y];
            y -= 1;
        }
    }
    h qq[n_counting-1];
}

// Define some counting qubits
qubit counting[n_counting];

// Allocate the qubit we'll 
// put the initial state on
qubit state;

// We want T |1> = exp(2*i*pi*phase) |1> = exp(i*pi/4)
// compute phase, should be 1 / 8;

// Initialize to |1>
x state;

// Put all others in a uniform superposition
h counting;

// Loop over and create ctrl-U**2k
int repetitions = 1;
for i in [0:n_counting] {
    print("i is ", i, repetitions);
    ctrl @ pow(repetitions) @ oracle counting[i], state;
    repetitions *= 2;
}

// Run inverse QFT 
iqft counting;

// Now lets measure the counting qubits
bit c[n_counting];
measure counting -> c;

// Backend is QPP which is lsb, 
// so return should be 100
for i in [0:n_counting]{
    print(c[i]);
}
+22 −72
Original line number Diff line number Diff line
@@ -86,18 +86,18 @@ antlrcpp::Any qasm3_expression_generator::visitTerminal(
          internal_value_type.cast<mlir::OpaqueType>().getTypeData().str() ==
              "Qubit") {
        if (current_value.getType().isa<mlir::MemRefType>()) {
          if (current_value.getType().cast<mlir::MemRefType>().getRank() == 1 &&
              current_value.getType().cast<mlir::MemRefType>().getShape()[0] ==
                  1) {
            current_value = builder.create<mlir::LoadOp>(
                location, current_value,
                get_or_create_constant_index_value(0, location, 64,
                                                   symbol_table, builder));
          if (current_value.getType().cast<mlir::MemRefType>().getRank() == 0) {
            current_value =
                builder.create<mlir::LoadOp>(location, current_value);
          } else {
            printErrorMessage("Terminator ']' -> Invalid qubit array index: ",
                              current_value);
          }
        }
        if (current_value.getType().getIntOrFloatBitWidth() < 64) {
          current_value = builder.create<mlir::ZeroExtendIOp>(
              location, current_value, builder.getI64Type());
        }
        update_current_value(builder.create<mlir::quantum::ExtractQubitOp>(
            location, get_custom_opaque_type("Qubit", builder.getContext()),
            indexed_variable_value, current_value));
@@ -171,11 +171,11 @@ antlrcpp::Any qasm3_expression_generator::visitComparsionExpression(

        // We need the comparison to be on the same bit width
        if (lhs_bw < rhs_bw) {
          rhs = builder.create<mlir::IndexCastOp>(
              location, rhs, builder.getIntegerType(lhs_bw));
        } else if (lhs_bw > rhs_bw) {
          lhs = builder.create<mlir::IndexCastOp>(
          lhs = builder.create<mlir::ZeroExtendIOp>(
              location, lhs, builder.getIntegerType(rhs_bw));
        } else if (lhs_bw > rhs_bw) {
          rhs = builder.create<mlir::ZeroExtendIOp>(
              location, rhs, builder.getIntegerType(lhs_bw));
        }

        // create the binary op value
@@ -453,7 +453,7 @@ antlrcpp::Any qasm3_expression_generator::visitXOrExpression(
          1, location, lhs_element_type, symbol_table, builder);
      llvm::ArrayRef<mlir::Value> zero_index(tmp2);

      llvm::ArrayRef<int64_t> shaperef{0};
      llvm::ArrayRef<int64_t> shaperef{};
      auto mem_type = mlir::MemRefType::get(shaperef, lhs_element_type);

      auto integer_attr2 = mlir::IntegerAttr::get(lhs_element_type, 0);
@@ -492,7 +492,14 @@ antlrcpp::Any qasm3_expression_generator::visitXOrExpression(
      builder.setInsertionPointToStart(headerBlock);

      auto load = builder.create<mlir::LoadOp>(
          location, loop_var_memref);  //, zero_index);
          location, loop_var_memref).result();  //, zero_index);
      
      if (load.getType().getIntOrFloatBitWidth() < b_val.getType().getIntOrFloatBitWidth()) {
        load = builder.create<mlir::ZeroExtendIOp>(location, load, b_val.getType());
      } else if (b_val.getType().getIntOrFloatBitWidth() < load.getType().getIntOrFloatBitWidth()) {
        b_val = builder.create<mlir::ZeroExtendIOp>(location, b_val, load.getType());
      }

      auto cmp = builder.create<mlir::CmpIOp>(
          location, mlir::CmpIPredicate::slt, load, b_val);
      builder.create<mlir::CondBranchOp>(location, cmp, bodyBlock, exitBlock);
@@ -1123,9 +1130,7 @@ antlrcpp::Any qasm3_expression_generator::visitExpressionTerminator(
    // If RHS is a memref<1xTYPE> then lets load it first
    if (auto rhs_mem = call_op.getType().dyn_cast_or_null<mlir::MemRefType>()) {
      call_op = builder.create<mlir::LoadOp>(
          location, call_op,
          get_or_create_constant_index_value(0, location, 64, symbol_table,
                                             builder));
          location, call_op);
    }
    // printErrorMessage("HELLO should we return the loaded result here?", ctx,
    // {call_op.getResult(0)});
@@ -1166,58 +1171,3 @@ antlrcpp::Any qasm3_expression_generator::visitExpressionTerminator(
}

}  // namespace qcor

/*keeping in case i need later
for (int j = 0; j < bit_width; j++) {
                auto j_val = get_or_create_constant_integer_value(
                    j, location, builder.getI32Type(), symbol_table, builder);
                auto load_bit_j =
                    builder.create<mlir::LoadOp>(location, var_to_cast, j_val);
                // Extend i1 to the same width as i
                auto load_j_ext = builder.create<mlir::ZeroExtendIOp>(
                    location, load_bit_j, int_value_type);

                // Negate bits[j] to get -bit[j]`
                auto neg_load_j = builder.create<mlir::SubIOp>(
                    location,
                    builder.create<mlir::ConstantOp>(location, init_attr),
                    load_j_ext);

                // load the current value of i
                auto load_i = builder.create<mlir::LoadOp>(
                    location, init_allocation,
                    get_or_create_constant_index_value(0, location, bit_width,
                                                       symbol_table, builder));

                // first = -bits[j] ^ i
                auto xored_val =
                    builder.create<mlir::XOrOp>(location, neg_load_j, load_i);

                // (1 << j)
                // create j integer index
                // auto j_val = get_or_create_constant_integer_value(
                //     j, location, int_value_type, symbol_table, builder);
                // second = (1 << j)
                j_val = builder.create<mlir::TruncateIOp>(location, j_val,
              int_value_type); auto shift_left_val =
              builder.create<mlir::ShiftLeftOp>( location,
                    get_or_create_constant_integer_value(
                        1, location, int_value_type, symbol_table, builder),
                    j_val);

                // (-bits[j] ^ i) & (1 << j)
                auto result = builder.create<mlir::AndOp>(location, xored_val,
                                                          shift_left_val);

                auto load_i2 = builder.create<mlir::LoadOp>(
                    location, init_allocation,
                    get_or_create_constant_index_value(0, location, bit_width,
                                                       symbol_table, builder));
                auto result_to_store =
                    builder.create<mlir::XOrOp>(location, load_i2, result);

                auto val = builder.create<mlir::StoreOp>(
                    location, result_to_store, init_allocation,
                    get_or_create_constant_index_value(0, location, 64,
                                                       symbol_table, builder));
              }*/
 No newline at end of file
+3 −3
Original line number Diff line number Diff line
@@ -678,9 +678,9 @@ antlrcpp::Any qasm3_visitor::visitClassicalAssignment(
  } else if (assignment_op == "^=") {
    current_value =
        builder.create<mlir::XOrOp>(location, load_result, load_result_rhs);
    llvm::ArrayRef<mlir::Value> zero_index2(get_or_create_constant_index_value(
        0, location, 64, symbol_table, builder));
    builder.create<mlir::StoreOp>(location, current_value, lhs, zero_index2);
    // llvm::ArrayRef<mlir::Value> zero_index2(get_or_create_constant_index_value(
        // 0, location, 64, symbol_table, builder));
    builder.create<mlir::StoreOp>(location, current_value, lhs);//, zero_index2);
  } else if (assignment_op == "=") {
    // FIXME This assumes we have a memref<1x??> = memref<1x??>
    // what if we have multiple elements in the memref???
+21 −13
Original line number Diff line number Diff line
@@ -214,19 +214,27 @@ antlrcpp::Any qasm3_visitor::visitLoopStatement(
        }

      } else {
        if (symbol_table.has_symbol(range->expression(1)->getText())) {
          b_value = symbol_table.get_symbol(range->expression(1)->getText());
        // if (symbol_table.has_symbol(range->expression(1)->getText())) {
        //   b_value = symbol_table.get_symbol(range->expression(1)->getText());
        //   b_value = builder.create<mlir::LoadOp>(location, b_value);
        //   if (b_value.getType() != int_type) {
        //     printErrorMessage("For loop a, b, and c types are not equal.",
        //                       context, {a_value, b_value});
        //   }
        // } else {
          qasm3_expression_generator exp_generator(builder, symbol_table,
                                                   file_name);
          exp_generator.visit(range->expression(1));
          b_value = exp_generator.current_value;
          if (b_value.getType().isa<mlir::MemRefType>()) {
            b_value = builder.create<mlir::LoadOp>(location, b_value);
          if (b_value.getType() != int_type) {
            printErrorMessage("For loop a, b, and c types are not equal.",
                              context, {a_value, b_value});
          }
        } else {
          b = symbol_table.evaluate_constant_integer_expression(
              range->expression(1)->getText());
          b_value = get_or_create_constant_integer_value(
              b, location, a_value.getType(), symbol_table, builder);
          }

          // b = symbol_table.evaluate_constant_integer_expression(
          //     range->expression(1)->getText());
          // b_value = get_or_create_constant_integer_value(
          //     b, location, a_value.getType(), symbol_table, builder);
        // }
      }

      // Create a new scope for the for loop
Loading