Commit c66ac227 authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

Fixes for loop unroll and inliner



- Type mismatches since affine loop var is always of index type.

- I don't know how to make affine loop to work with memref bounds yet; hence don't try to create an affine loop in that case

- Run inliner both before and after loop unroll to make sure CallOps are inlined

Signed-off-by: default avatarThien Nguyen <nguyentm@ornl.gov>
parent 870ff220
Loading
Loading
Loading
Loading
+4 −3
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@
#include "Quantum/QuantumOps.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Transforms/InliningUtils.h"

#include <iostream>
using namespace mlir;
using namespace mlir::quantum;
namespace {
@@ -29,9 +29,10 @@ struct QuantumInlinerInterface : public DialectInlinerInterface {
  // FIXME: there is a weird error when qalloc is inlined at MLIR level
  // hence, just allow VSOp to be inlined for the timebeing.
  // i.e. all quantum subroutines that only contain VSOp's can be inlined.
  bool isLegalToInline(Operation *op, Region *regione, bool,
  bool isLegalToInline(Operation *op, Region *region, bool,
                       BlockAndValueMapping &) const final {
    if (dyn_cast_or_null<mlir::quantum::ValueSemanticsInstOp>(op)) {
    if (dyn_cast_or_null<mlir::quantum::ValueSemanticsInstOp>(op) ||
        dyn_cast_or_null<mlir::quantum::ExtractQubitOp>(op)) {
      return true;
    }

+8 −0
Original line number Diff line number Diff line
@@ -122,6 +122,10 @@ antlrcpp::Any qasm3_expression_generator::visitTerminal(
          current_value = builder.create<mlir::ZeroExtendIOp>(
              location, current_value, builder.getI64Type());
        }
        if (!current_value.getType().isa<mlir::IntegerType>()) {
          current_value = builder.create<mlir::IndexCastOp>(
              location, builder.getI64Type(), current_value);
        }
        update_current_value(builder.create<mlir::quantum::ExtractQubitOp>(
            location, get_custom_opaque_type("Qubit", builder.getContext()),
            indexed_variable_value, current_value));
@@ -418,6 +422,10 @@ antlrcpp::Any qasm3_expression_generator::visitAdditiveExpression(
          rhs =
              builder.create<mlir::ZeroExtendIOp>(location, rhs, lhs.getType());
        }

        if (lhs.getType() != rhs.getType()) {
          rhs = builder.create<mlir::IndexCastOp>(location, lhs.getType(), rhs);
        }
        createOp<mlir::AddIOp>(location, lhs, rhs).result();
      } else {
        printErrorMessage("Could not perform addition, incompatible types: ",
+11 −2
Original line number Diff line number Diff line
@@ -191,12 +191,18 @@ antlrcpp::Any qasm3_visitor::visitLoopStatement(
          c_value = get_or_create_constant_integer_value(c, location, int_type,
                                                         symbol_table, builder);
      
      // Either a_value or b_value (loop bounds) is a memref
      // (For some reason, affine loop inliner doesn't work in this case, 
      // causing some validation errors)
      bool loop_bounds_are_memref = false;
      
      qasm3_expression_generator exp_generator(builder, symbol_table,
                                               file_name);
      exp_generator.visit(range->expression(0));
      a_value = exp_generator.current_value;
      if (a_value.getType().isa<mlir::MemRefType>()) {
        a_value = builder.create<mlir::LoadOp>(location, a_value);
        loop_bounds_are_memref = true;
      }

      if (n_expr == 3) {
@@ -206,6 +212,7 @@ antlrcpp::Any qasm3_visitor::visitLoopStatement(
        b_value = exp_generator.current_value;
        if (b_value.getType().isa<mlir::MemRefType>()) {
          b_value = builder.create<mlir::LoadOp>(location, b_value);
          loop_bounds_are_memref = true;
        }

        if (symbol_table.has_symbol(range->expression(1)->getText())) {
@@ -231,6 +238,7 @@ antlrcpp::Any qasm3_visitor::visitLoopStatement(
        b_value = exp_generator.current_value;
        if (b_value.getType().isa<mlir::MemRefType>()) {
          b_value = builder.create<mlir::LoadOp>(location, b_value);
          loop_bounds_are_memref = true;
        }
      }

@@ -239,7 +247,8 @@ antlrcpp::Any qasm3_visitor::visitLoopStatement(

      // HACK: Currently, we don't handle 'if', 'break', 'continue'
      // in the Affine for loop yet.
      if (program_block_str.find("if") == std::string::npos &&
      if (!loop_bounds_are_memref &&
          program_block_str.find("if") == std::string::npos &&
          program_block_str.find("break") == std::string::npos &&
          program_block_str.find("continue") == std::string::npos) {
        // Can use Affine for loop....
+4 −1
Original line number Diff line number Diff line
@@ -94,7 +94,10 @@ antlrcpp::Any qasm3_visitor::visitQuantumMeasurementAssignment(
              mlir::Identifier::get("quantum", builder.getContext());
          auto qubit_type = mlir::OpaqueType::get(builder.getContext(), dialect,
                                                  qubit_type_name);

          if (!qbit.getType().isa<mlir::IntegerType>()) {
            qbit = builder.create<mlir::IndexCastOp>(
                location, builder.getI64Type(), qbit);
          }
          value = builder.create<mlir::quantum::ExtractQubitOp>(
              location, qubit_type, qubits, qbit);
        } else {
+10 −1
Original line number Diff line number Diff line
@@ -317,7 +317,10 @@ antlrcpp::Any qasm3_visitor::visitQuantumGateCall(

          auto qubit_type =
              get_custom_opaque_type("Qubit", builder.getContext());

          if (!qbit.getType().isa<mlir::IntegerType>()) {
            qbit = builder.create<mlir::IndexCastOp>(
                location, builder.getI64Type(), qbit);
          }
          value = builder.create<mlir::quantum::ExtractQubitOp>(
              location, qubit_type, qubits, qbit);
          if (!symbol_table.has_symbol(qbit_var_name + idx_str))
@@ -337,6 +340,12 @@ antlrcpp::Any qasm3_visitor::visitQuantumGateCall(
              value = builder.create<mlir::ZeroExtendIOp>(location, value,
                                                          builder.getI64Type());
            }

            if (!value.getType().isa<mlir::IntegerType>()) {
              value = builder.create<mlir::IndexCastOp>(
                  location, builder.getI64Type(), value);
            }

            value = builder.create<mlir::quantum::ExtractQubitOp>(
                location, qubit_type, qubits, value);
            if (!symbol_table.has_symbol(qbit_var_name + idx_str))
Loading