Commit eed8cabd authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

Using consistent types in Quantum Dialect



We must make sure the types at the MLIR level self-consistent to be able to operate passes at the MLIR level:

- Measure => return Result

- Add a cast op to convert to i1 (bool)

- Proper use of I64 for q.extract (doesn't matter if we lower to LLVM, but will complain at MLIR level)

Signed-off-by: default avatarThien Nguyen <nguyentm@ornl.gov>
parent 6d437daa
Loading
Loading
Loading
Loading
+7 −0
Original line number Diff line number Diff line
@@ -152,4 +152,11 @@ def CreateStringLiteralOp : QuantumOp<"createString", []> {
  p << "q.create_string(\"" << op.text() << "\")"; }];
}

// Cast QIR Result to bool (i1 type)
def ResultCastOp : QuantumOp<"resultCast", []> {
    let arguments = (ins ResultType:$measure_result);
    let results = (outs I1:$bit_result);
    let printer = [{  auto op = *this;
  p << "q.resultCast" << "(" << op.measure_result() << ") : " << op.bit_result().getType(); }];
}
#endif // Quantum_OPS
 No newline at end of file
+1 −1
Original line number Diff line number Diff line
@@ -38,7 +38,7 @@ class qasm3_visitor : public qasm3::qasm3BaseVisitor {
    mlir::Identifier dialect = mlir::Identifier::get("quantum", context);
    qubit_type = mlir::OpaqueType::get(context, dialect, qubit_type_name);
    array_type = mlir::OpaqueType::get(context, dialect, array_type_name);
    result_type = mlir::IntegerType::get(context, 1);
    result_type = mlir::OpaqueType::get(context, dialect, result_type_name);
    symbol_table.set_op_builder(builder);
  }

+14 −4
Original line number Diff line number Diff line
@@ -166,8 +166,11 @@ antlrcpp::Any qasm3_visitor::visitQuantumMeasurementAssignment(
          auto qbit_idx = qubit_indices[i];
          auto bit_idx = bit_indices[i];

          mlir::Value idx_val = get_or_create_constant_index_value(
              qbit_idx, location, 64, symbol_table, builder);
          // !IMPORTANT! q.extract expects i64 as index.
          // Using index type will cause validation issue at the MLIR level.
          // (i.e. requires all-the-way-to-LLVM lowering for types to match)
          mlir::Value idx_val = get_or_create_constant_integer_value(
              qbit_idx, location, builder.getI64Type(), symbol_table, builder);
          mlir::Value bit_idx_val = get_or_create_constant_index_value(
              bit_idx, location, 64, symbol_table, builder);

@@ -251,18 +254,25 @@ antlrcpp::Any qasm3_visitor::visitQuantumMeasurementAssignment(
      }

      for (int i = 0; i < nqubits; i++) {
        // q.Extract must use integer type (not index type)
        mlir::Value q_idx_val = get_or_create_constant_integer_value(
            i, location, builder.getI64Type(), symbol_table, builder);
        mlir::Value idx_val = get_or_create_constant_index_value(
            i, location, 64, symbol_table, builder);

        auto extract = builder.create<mlir::quantum::ExtractQubitOp>(
            location, qubit_type, value, idx_val);
            location, qubit_type, value, q_idx_val);

        auto instop = builder.create<mlir::quantum::InstOp>(
            location, result_type, str_attr, llvm::makeArrayRef(extract.qbit()),
            llvm::makeArrayRef(std::vector<mlir::Value>{}));
        
        // Cast Measure Result -> Bit (i1)
        auto cast_bit_op = builder.create<mlir::quantum::ResultCastOp>(
            location, builder.getIntegerType(1), instop.bit());

        builder.create<mlir::StoreOp>(
            location, instop.bit(), bit_value,
            location, cast_bit_op.bit_result(), bit_value,
            llvm::makeArrayRef(std::vector<mlir::Value>{idx_val}));
      }
    }
+18 −6
Original line number Diff line number Diff line
@@ -134,12 +134,7 @@ LogicalResult InstOpLowering::matchAndRewrite(
                                         llvm::makeArrayRef(func_args));

  if (inst_name == "mz") {
    auto bitcast = rewriter.create<LLVM::BitcastOp>(
        loc, LLVM::LLVMPointerType::get(rewriter.getIntegerType(1)),
        c.getResult(0));
    auto o = rewriter.create<LLVM::LoadOp>(loc, rewriter.getIntegerType(1),
                                           bitcast.res());
    rewriter.replaceOp(op, o.res());
    rewriter.replaceOp(op, c.getResult(0));
  } else {
    rewriter.eraseOp(op);
  }
@@ -149,4 +144,21 @@ LogicalResult InstOpLowering::matchAndRewrite(

  return success();
}

LogicalResult ResultCastOpLowering::matchAndRewrite(
    Operation *op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  auto loc = op->getLoc();
  auto resultCastOp = cast<mlir::quantum::ResultCastOp>(op);
  auto qir_result = resultCastOp.measure_result();
  // Cast Result* -> Bool* (i1*)
  auto bitcast = rewriter.create<LLVM::BitcastOp>(
      loc, LLVM::LLVMPointerType::get(rewriter.getIntegerType(1)), qir_result);
  // Load bool from bool*
  auto bool_result = rewriter.create<LLVM::LoadOp>(
      loc, rewriter.getIntegerType(1), bitcast.res());
  rewriter.replaceOp(op, bool_result.res());

  return success();
}
}  // namespace qcor
 No newline at end of file
+16 −0
Original line number Diff line number Diff line
@@ -34,4 +34,20 @@ public:
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};

// Lower Result type casting:
// In QCOR QIR runtime, Result is just a bool (i1)
// hence, just need to do a type cast and load.
class ResultCastOpLowering : public ConversionPattern {
protected:
public:
  explicit ResultCastOpLowering(MLIRContext *context)
      : ConversionPattern(mlir::quantum::ResultCastOp::getOperationName(), 1,
                          context) {}

  // Match and replace all InstOps
  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace qcor
 No newline at end of file
Loading