Commit be437613 authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

Matches MLIR Op with QIR signature



we've constructed callable with captured vars for invocation.

Signed-off-by: default avatarThien Nguyen <nguyentm@ornl.gov>
parent e8331f06
Loading
Loading
Loading
Loading
+6 −43
Original line number Diff line number Diff line
@@ -239,52 +239,15 @@ def CreateCallableOp : QuantumOp<"createCallable", []> {
    }];
}

def YieldOp : QuantumOp<"yield", [NoSideEffect, Terminator]> {
  let summary = "conditional termination operation";
  let arguments = (ins Variadic<AnyType>:$results);
  let builders = [OpBuilderDAG<(ins), [{ /* nothing to do */ }]>];
  let printer = [{ p << "q.yield"; }];
}

def ConditionalOp : QuantumOp<"ifOp", [SingleBlockImplicitTerminator<"YieldOp">, RecursiveSideEffects, NoRegionArguments]> {
def ConditionalOp : QuantumOp<"ifOp", []> {
  let summary = "if-then-else operation conditioned on a quantum Measure";
  // Must be conditioned on a Result type
  let arguments = (ins ResultType:$result_bit);
  // Must be conditioned on a Result type (only then clause for now...)
  let arguments = (ins ResultType:$result_bit, CallableType:$then_callable);
  let results = (outs);
  let regions = (region SizedRegion<1>:$thenRegion, AnyRegion:$elseRegion);
  let skipDefaultBuilders = 1;
  
  let skipDefaultBuilders = 1;
  let builders = [
    OpBuilderDAG<(ins "Value":$cond, "bool":$withElseRegion)>
  ];

  let extraClassDeclaration = [{
    OpBuilder getThenBodyBuilder(OpBuilder::Listener *listener = nullptr) {
      Block* body = getBody(0);
      return OpBuilder::atBlockTerminator(body, listener);
    }
    OpBuilder getElseBodyBuilder(OpBuilder::Listener *listener = nullptr) {
      Block* body = getBody(1);
      return OpBuilder::atBlockTerminator(body, listener);
    }
  }];

  let printer = [{
    auto op = *this;
    p << "q.If " << op.result_bit();  
    p.printRegion(op.thenRegion(),
                /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/false);

    // Print the 'else' regions if it exists and has a block.
    auto &elseRegion = op.elseRegion();
    if (!elseRegion.empty()) {
      p << " else";
      p.printRegion(elseRegion,
                  /*printEntryBlockArgs=*/false,
                  /*printBlockTerminators=*/false);
    }
    p << "q.If " << op.result_bit() << " { invoke " << op.then_callable() << "}";
  }];
}

+1 −27
Original line number Diff line number Diff line
@@ -60,29 +60,3 @@ void QuantumDialect::initialize() {
//   printer << ")";

// }
 No newline at end of file

//===----------------------------------------------------------------------===//
// ConditionalOp
//===----------------------------------------------------------------------===//

void ConditionalOp::build(OpBuilder &builder, OperationState &result, Value cond,
                 bool withElseRegion) {
  result.addOperands(cond);
  OpBuilder::InsertionGuard guard(builder);
  Region *thenRegion = result.addRegion();
  builder.createBlock(thenRegion);
  auto defaultBuilder = [&](OpBuilder &nested, Location loc) {
    ConditionalOp::ensureTerminator(*nested.getInsertionBlock()->getParent(),
                                    nested, loc);
  };

  defaultBuilder(builder, result.location);

  Region *elseRegion = result.addRegion();
  if (!withElseRegion) {
    return;
  }

  builder.createBlock(elseRegion);
  defaultBuilder(builder, result.location);
}
 No newline at end of file
+2 −16
Original line number Diff line number Diff line
@@ -68,19 +68,13 @@ mlir::Value create_capture_callable_gen(
  mlir::Identifier dialect = mlir::Identifier::get("quantum", context);
  llvm::StringRef tuple_type_name("Tuple");
  auto tuple_type = mlir::OpaqueType::get(context, dialect, tuple_type_name);
  llvm::StringRef array_type_name("Array");
  auto array_type = mlir::OpaqueType::get(context, dialect, array_type_name);
  llvm::StringRef callable_type_name("Callable");
  auto callable_type =
      mlir::OpaqueType::get(context, dialect, callable_type_name);
  llvm::StringRef qubit_type_name("Qubit");
  auto qubit_type = mlir::OpaqueType::get(context, dialect, qubit_type_name);

  const std::vector<mlir::Type> argument_types{tuple_type, tuple_type,
                                               tuple_type};
  auto func_type = builder.getFunctionType(argument_types, llvm::None);
  const std::string BODY_WRAPPER_SUFFIX = "__body__wrapper";
  std::vector<mlir::FuncOp> all_wrapper_funcs;
  // Body wrapper:
  const std::string wrapper_fn_name = func_name + BODY_WRAPPER_SUFFIX;
  mlir::FuncOp function_op(mlir::FuncOp::create(builder.getUnknownLoc(),
@@ -128,9 +122,6 @@ antlrcpp::Any qasm3_visitor::visitBranchingStatement(
    std::cout << "This is a simple Measure check\n";
    auto meas_var =
        symbol_table.try_lookup_meas_result(bit_check_conditional->var_name);
    auto ifOp =
        builder.create<mlir::quantum::ConditionalOp>(location, meas_var.value(),
                                                     /*withElseRegion=*/false);
    
    // Strategy: we wrap the body as a Callable capturing
    // all avaiable variables at the current scope.
@@ -178,13 +169,8 @@ antlrcpp::Any qasm3_visitor::visitBranchingStatement(

    auto then_body_callable = create_capture_callable_gen(
        builder, tmp_func_name, m_module, function, argument_values);
    // Create a call to the function:
    // FIXME: this should be wrapped as a callable...
    auto then_body_builder = ifOp.getThenBodyBuilder();
    auto body_call_op = then_body_builder.create<mlir::CallOp>(
        then_body_builder.getUnknownLoc(), function,
        llvm::makeArrayRef(argument_values));

    auto ifOp = builder.create<mlir::quantum::ConditionalOp>(
        location, meas_var.value(), then_body_callable);
    // Done
    return 0;
  }
+2 −0
Original line number Diff line number Diff line
@@ -418,6 +418,8 @@ void __quantum__qis__applyifelseintrinsic__body(Result *r,
      // We don't support else block atm yet.
      assert(!clb_on_zero);
      // Execute the callable: this will append NISQ instructions to the IfStmt
      // Important: implicit in this is the fact that the Callable capture the
      // whole context of the parent scope..
      clb_on_one->invoke(nullptr, nullptr);

      // Add the whole IfStmt to the program