Commit 59afcfbc authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

Fixes to support uint type at MLIR level



Since we use std dialect for basic operations, we need to handle the uint type in our quantum dialect as well.
std dialect just doesn't like uint type.

Strategy: add a cast op to our dialect to allow the types to match up at MLIR level.
This cast will be translated to LLVM::DialectCastOp during lowering --> all types matched up.

Signed-off-by: default avatarThien Nguyen <nguyentm@ornl.gov>
parent 76cd4343
Loading
Loading
Loading
Loading
+13 −0
Original line number Diff line number Diff line
@@ -159,4 +159,17 @@ def ResultCastOp : QuantumOp<"resultCast", []> {
    let printer = [{  auto op = *this;
  p << "q.resultCast" << "(" << op.measure_result() << ") : " << op.bit_result().getType(); }];
}

// Sign-Unsign cast:
// Rationale: std dialect only accepts signless type (i.e. int but not uint)
// we need to have this cast op in the dialect to finally lower to LLVM cast 
// which can handle int -> uint casting at the final lowering phase.
// Note: std.index_cast cannot handle int -> unit casting (one of the type must be an index type).
def IntegerCastOp : QuantumOp<"integerCast", []> {
    let arguments = (ins AnyInteger:$input);
    let results = (outs AnyInteger:$output);
    let printer = [{  auto op = *this;
  p << "q.integerCast" << "(" << op.input() << ") : " << op.output().getType(); }];
}

#endif // Quantum_OPS
 No newline at end of file
+20 −2
Original line number Diff line number Diff line
@@ -773,13 +773,31 @@ antlrcpp::Any qasm3_expression_generator::visitExpressionTerminator(
      auto idx = std::stoi(integer->getText());
      // std::cout << "Integer Terminator " << integer->getText() << ", " << idx
      // << "\n";
      const auto getSignlessIntegerType = [](mlir::OpBuilder &opBuilder,
                                             mlir::IntegerType in_intType) {
        return in_intType.isSignless()
                   ? in_intType
                   : opBuilder.getIntegerType(in_intType.getWidth());
      };
      auto integer_attr = mlir::IntegerAttr::get(
          (internal_value_type.dyn_cast_or_null<mlir::IntegerType>()
               ? internal_value_type.cast<mlir::IntegerType>()
               ? getSignlessIntegerType(
                     builder, internal_value_type.cast<mlir::IntegerType>())
               : builder.getI64Type()),
          idx);

      assert(integer_attr.getType().cast<mlir::IntegerType>().isSignless());
      current_value = builder.create<mlir::ConstantOp>(location, integer_attr);
      if (internal_value_type.dyn_cast_or_null<mlir::IntegerType>() &&
          !internal_value_type.cast<mlir::IntegerType>().isSignless()) {
        // Make sure we cast the constant value appropriately.
        // i.e. respect the signed/unsigned of the requested type.
        current_value = builder.create<mlir::quantum::IntegerCastOp>(
            location, internal_value_type.cast<mlir::IntegerType>(),
            builder.create<mlir::ConstantOp>(location, integer_attr)).output();
      } else {
        current_value =
            builder.create<mlir::ConstantOp>(location, integer_attr);
      }
    }
    return 0;
  } else if (auto real = ctx->RealNumber()) {
+13 −0
Original line number Diff line number Diff line
@@ -161,4 +161,17 @@ LogicalResult ResultCastOpLowering::matchAndRewrite(

  return success();
}

LogicalResult IntegerCastOpLowering::matchAndRewrite(
    Operation *op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  auto loc = op->getLoc();
  auto resultCastOp = cast<mlir::quantum::IntegerCastOp>(op);
  auto to_be_cast = resultCastOp.input();
  mlir::IntegerType type = to_be_cast.getType().cast<mlir::IntegerType>();
  auto cast_op = rewriter.create<LLVM::DialectCastOp>(
      loc, rewriter.getIntegerType(type.getWidth(), false), to_be_cast);
  rewriter.replaceOp(op, cast_op.res());
  return success();
}
}  // namespace qcor
 No newline at end of file
+12 −1
Original line number Diff line number Diff line
@@ -45,7 +45,18 @@ public:
      : ConversionPattern(mlir::quantum::ResultCastOp::getOperationName(), 1,
                          context) {}

  // Match and replace all InstOps
  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};

class IntegerCastOpLowering : public ConversionPattern {
protected:
public:
  explicit IntegerCastOpLowering(MLIRContext *context)
      : ConversionPattern(mlir::quantum::IntegerCastOp::getOperationName(), 1,
                          context) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
+1 −0
Original line number Diff line number Diff line
@@ -86,6 +86,7 @@ void QuantumToLLVMLoweringPass::runOnOperation() {
                                  function_names);
  patterns.insert<ValueSemanticsInstOpLowering>(&getContext(), function_names);
  patterns.insert<ResultCastOpLowering>(&getContext());
  patterns.insert<IntegerCastOpLowering>(&getContext());
  patterns.insert<SetQregOpLowering>(&getContext(), variables);
  patterns.insert<ExtractQubitOpConversion>(&getContext(), typeConverter,
                                            variables, qubit_extract_map);