Commit 29de0aba authored by Mccaskey, Alex's avatar Mccaskey, Alex
Browse files

bug fixes, work on subroutine return to handle return statements for handler kernels

parent a6ee647d
Loading
Loading
Loading
Loading
+183 −141
Original line number Diff line number Diff line
@@ -7,13 +7,47 @@

namespace qcor {

void OpenQasmV3MLIRGenerator::initialize_mlirgen(
    const std::string func_name, std::vector<mlir::Type> arg_types,
    std::vector<std::string> arg_var_names, mlir::Type return_type) {
  mlir::FunctionType func_type2;
  if (return_type) {
    func_type2 =
        builder.getFunctionType(llvm::makeArrayRef(arg_types), return_type);
  } else {
      func_type2 =
          builder.getFunctionType(llvm::makeArrayRef(arg_types), llvm::None);
    }
    auto proto2 = mlir::FuncOp::create(
        builder.getUnknownLoc(), "__internal_mlir_" + func_name, func_type2);
    mlir::FuncOp function2(proto2);
    std::string file_name = "internal_mlirgen_qcor_";
    auto save_main_entry_block = function2.addEntryBlock();
    builder.setInsertionPointToStart(save_main_entry_block);
    m_module.push_back(function2);
    main_entry_block = save_main_entry_block;

    // Configure block arguments
    visitor = std::make_shared<qasm3_visitor>(builder, m_module, file_name);
    auto symbol_table = visitor->getScopedSymbolTable();
    auto arguments = main_entry_block->getArguments();
    for (int i = 0; i < arg_var_names.size(); i++) {
      symbol_table->add_symbol(arg_var_names[i], arguments[i]);
    }

    add_main = false;
    if (!return_type) {
      add_custom_return = true;
    }

    return;
  }

  void OpenQasmV3MLIRGenerator::initialize_mlirgen(bool _add_entry_point,
                                                   const std::string function) {
    file_name = function;
    add_entry_point = _add_entry_point;

  m_module = mlir::ModuleOp::create(builder.getUnknownLoc());

    // Useful opaque type defs
    llvm::StringRef qubit_type_name("Qubit"), array_type_name("Array"),
        result_type_name("Result");
@@ -38,8 +72,8 @@ void OpenQasmV3MLIRGenerator::initialize_mlirgen(bool _add_entry_point,

      if (add_entry_point) {
        std::vector<mlir::Type> arg_types_vec{int_type, argv_type};
      auto func_type =
          builder.getFunctionType(llvm::makeArrayRef(arg_types_vec), int_type);
        auto func_type = builder.getFunctionType(
            llvm::makeArrayRef(arg_types_vec), int_type);
        auto proto =
            mlir::FuncOp::create(builder.getUnknownLoc(), "main", func_type);
        mlir::FuncOp function(proto);
@@ -97,7 +131,9 @@ void OpenQasmV3MLIRGenerator::mlirgen(const std::string &src) {
    using namespace antlr4;
    using namespace qasm3;

    if (!visitor) {
      visitor = std::make_shared<qasm3_visitor>(builder, m_module, file_name);
    }

    ANTLRInputStream input(src);
    qasm3Lexer lexer(&input);
@@ -139,34 +175,40 @@ void OpenQasmV3MLIRGenerator::mlirgen(const std::string &src) {

  void OpenQasmV3MLIRGenerator::finalize_mlirgen() {
    auto scoped_symbol_table = visitor->getScopedSymbolTable();
  if (auto b = scoped_symbol_table.get_last_created_block()) {
    if (auto b = scoped_symbol_table->get_last_created_block()) {
      builder.setInsertionPointToEnd(b);
    }
    auto all_qalloc_ops =
      scoped_symbol_table.get_global_symbols_of_type<mlir::quantum::QallocOp>();
        scoped_symbol_table
            ->get_global_symbols_of_type<mlir::quantum::QallocOp>();
    for (auto op : all_qalloc_ops) {
      builder.create<mlir::quantum::DeallocOp>(builder.getUnknownLoc(), op);
    }

    // Add any function names that we created.
  auto fnames = scoped_symbol_table.get_seen_function_names();
    auto fnames = scoped_symbol_table->get_seen_function_names();
    for (auto f : fnames) {
      function_names.push_back(f);
    }

    if (add_main) {
    if (auto b = scoped_symbol_table.get_last_created_block()) {
      if (auto b = scoped_symbol_table->get_last_created_block()) {
        builder.setInsertionPointToEnd(b);
      } else {
        builder.setInsertionPointToEnd(main_entry_block);
      }

      auto integer_attr = mlir::IntegerAttr::get(builder.getI32Type(), 0);
    auto ret =
        builder.create<mlir::ConstantOp>(builder.getUnknownLoc(), integer_attr);
      auto ret = builder.create<mlir::ConstantOp>(builder.getUnknownLoc(),
                                                  integer_attr);
      builder.create<mlir::ReturnOp>(builder.getUnknownLoc(),
                                     llvm::ArrayRef<mlir::Value>(ret));
    }

    if (add_custom_return) {
      builder.create<mlir::ReturnOp>(builder.getUnknownLoc(),
                                     llvm::ArrayRef<mlir::Value>());
    }
  }

}  // namespace qcor
 No newline at end of file
+16 −1
Original line number Diff line number Diff line
@@ -10,6 +10,10 @@ class OpenQasmV3MLIRGenerator : public qcor::QuantumMLIRGenerator {
 protected:
  std::string file_name = "main";
  bool add_entry_point = true;
  bool add_custom_return = false;

  mlir::Type return_type;

  mlir::Type qubit_type;
  mlir::Type array_type;
  mlir::Type result_type;
@@ -20,7 +24,18 @@ class OpenQasmV3MLIRGenerator : public qcor::QuantumMLIRGenerator {

 public:
  OpenQasmV3MLIRGenerator(mlir::MLIRContext &context)
      : QuantumMLIRGenerator(context) {}
      : QuantumMLIRGenerator(context) {
    m_module = mlir::ModuleOp::create(builder.getUnknownLoc());
  }
  OpenQasmV3MLIRGenerator(mlir::OpBuilder b, mlir::MLIRContext &ctx)
      : QuantumMLIRGenerator(b, ctx) {
    m_module = mlir::ModuleOp::create(builder.getUnknownLoc());
  }

  void initialize_mlirgen(const std::string func_name,
                          std::vector<mlir::Type> arg_types,
                          std::vector<std::string> arg_var_names,
                          mlir::Type return_type);
  void initialize_mlirgen(bool add_entry_point = true,
                          const std::string file_name = "") override;
  void mlirgen(const std::string &src) override;
+1 −1
Original line number Diff line number Diff line
@@ -27,7 +27,7 @@ namespace qcor {
class qasm3_visitor : public qasm3::qasm3BaseVisitor {
 public:
  // Return the symbol table.
  ScopedSymbolTable& getScopedSymbolTable() { return symbol_table; }
  ScopedSymbolTable* getScopedSymbolTable() { return &symbol_table; }

  // The constructor, instantiates commonly used opaque types
  qasm3_visitor(mlir::OpBuilder b, mlir::ModuleOp m, std::string& fname)
+21 −12
Original line number Diff line number Diff line
@@ -166,7 +166,8 @@ antlrcpp::Any qasm3_visitor::visitReturnStatement(
    value = symbol_table.get_symbol(ret_stmt);
    // Actually return value if it is a bit[],
    // load and return if it is a bit
    // printErrorMessage("Putting this here til I fix this");

    if (current_function_return_type) {  // this means it is a subroutine
      if (!current_function_return_type.isa<mlir::MemRefType>()) {
        if (current_function_return_type.isa<mlir::IntegerType>() &&
            current_function_return_type.getIntOrFloatBitWidth() == 1) {
@@ -176,10 +177,19 @@ antlrcpp::Any qasm3_visitor::visitReturnStatement(
          llvm::ArrayRef<mlir::Value> zero_index(tmp);
          value = builder.create<mlir::LoadOp>(location, value, zero_index);
        } else {
        value = builder.create<mlir::LoadOp>(location, value);  //, zero_index);
          value =
              builder.create<mlir::LoadOp>(location, value);  //, zero_index);
        }
      } else {
        printErrorMessage("We do not return memrefs from subroutines.",
                          context);
      }
    } else {
      printErrorMessage("We do not return memrefs from subroutines.", context);
      if (auto t = value.getType().dyn_cast_or_null<mlir::MemRefType>()) {
        if (t.getRank() == 0) {
          value = builder.create<mlir::LoadOp>(location, value);
        }
      }
    }

  } else {
@@ -192,7 +202,6 @@ antlrcpp::Any qasm3_visitor::visitReturnStatement(
      visitChildren(context->statement());
      value = symbol_table.get_last_value_added();
    }
    
  }
  is_return_stmt = false;

+13 −0
Original line number Diff line number Diff line
@@ -114,6 +114,11 @@ LogicalResult PrintOpLowering::matchAndRewrite(
          frmt_spec += "%d";
        }
        ss << "_bit_array_b_" << dim;
      } else if (mem_ref_type.getElementType().isa<mlir::IntegerType>() &&
                 mem_ref_type.getRank() == 0 &&
                 mem_ref_type.getElementType().getIntOrFloatBitWidth() == 1) {
        frmt_spec += "%d";
        ss << "_bit_array_b_0";
      }
    } else {
      std::cout << "Currently invalid type to print.\n";
@@ -155,6 +160,7 @@ LogicalResult PrintOpLowering::matchAndRewrite(
          mem_ref_type.getRank() > 0 &&
          mem_ref_type.getElementType().getIntOrFloatBitWidth() == 1) {
        // This is a bit array...

        auto dim = mem_ref_type.getShape()[0];
        for (int i = 0; i < dim; i++) {
          auto attr = mlir::IntegerAttr::get(rewriter.getIndexType(), i);
@@ -163,6 +169,13 @@ LogicalResult PrintOpLowering::matchAndRewrite(
              loc, o, llvm::makeArrayRef(std::vector<mlir::Value>{ii}));
          args.push_back(z);
        }

        continue;
      } else if (mem_ref_type.getElementType().isa<mlir::IntegerType>() &&
                 mem_ref_type.getRank() == 0 &&
                 mem_ref_type.getElementType().getIntOrFloatBitWidth() == 1) {
        auto z = rewriter.create<mlir::LoadOp>(loc, o);
        args.push_back(z);
        continue;
      }
    }