Commit 2f388d23 authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

Use scf for op in the power handler



It's cleaner and we don't need to handle last created block at the upper level.

Signed-off-by: default avatarThien Nguyen <nguyentm@ornl.gov>
parent ffd44634
Loading
Loading
Loading
Loading
+34 −90
Original line number Diff line number Diff line

#include "expression_handler.hpp"
#include "mlir/Dialect/SCF/SCF.h"

using namespace qasm3;

namespace qcor {
@@ -555,100 +557,42 @@ antlrcpp::Any qasm3_expression_generator::visitXOrExpression(
      // }

      // Will need to compute this via a loop
      auto tmp = get_or_create_constant_integer_value(
          0, location, builder.getI64Type(), symbol_table, builder);
      auto tmp2 = get_or_create_constant_index_value(0, location, 64,
                                                     symbol_table, builder);
      auto tmp3 = get_or_create_constant_integer_value(
          1, location, lhs_element_type, symbol_table, builder);
      llvm::ArrayRef<mlir::Value> zero_index(tmp2);
      // Upper bound = exponent (rhs)
      mlir::Value ubs_val = rhs;
      if (!ubs_val.getType().isa<mlir::IndexType>()) {
        ubs_val = builder.create<mlir::IndexCastOp>(
            location, builder.getIndexType(), ubs_val);
      }

      // Create the result memref, initialized to 1
      llvm::ArrayRef<int64_t> shaperef{};
      auto mem_type = mlir::MemRefType::get(shaperef, lhs_element_type);

      auto integer_attr2 = mlir::IntegerAttr::get(lhs_element_type, 0);
      
      assert(integer_attr2.getType().cast<mlir::IntegerType>().isSignless());
      auto ret2 = builder.create<mlir::ConstantOp>(location, integer_attr2);
      
      auto integer_attr3 = mlir::IntegerAttr::get(lhs_element_type, 1);
      assert(integer_attr3.getType().cast<mlir::IntegerType>().isSignless());
      auto ret3 = builder.create<mlir::ConstantOp>(location, integer_attr3);

      mlir::Value loop_var_memref = builder.create<mlir::AllocaOp>(
          location, mlir::MemRefType::get(shaperef, builder.getI64Type()));
      builder.create<mlir::StoreOp>(location, ret2, loop_var_memref);

      mlir::Value product_memref = builder.create<mlir::AllocaOp>(
          location, mlir::MemRefType::get(shaperef, lhs_element_type));
      builder.create<mlir::StoreOp>(location, ret3, product_memref);

      auto integer_attr = mlir::IntegerAttr::get(builder.getI64Type(), 1);
      auto ret = builder.create<mlir::ConstantOp>(location, integer_attr);
      auto b_val = rhs;
      auto c_val = ret;

      // Save the current builder point
      // auto savept = builder.saveInsertionPoint();
      auto loaded_var = builder.create<mlir::LoadOp>(
          location, loop_var_memref);  //, zero_index);

      auto savept = builder.saveInsertionPoint();
      auto currRegion = builder.getBlock()->getParent();
      auto headerBlock = builder.createBlock(currRegion, currRegion->end());
      auto bodyBlock = builder.createBlock(currRegion, currRegion->end());
      auto incBlock = builder.createBlock(currRegion, currRegion->end());
      mlir::Block* exitBlock =
          builder.createBlock(currRegion, currRegion->end());
      builder.restoreInsertionPoint(savept);

      builder.create<mlir::BranchOp>(location, headerBlock);
      builder.setInsertionPointToStart(headerBlock);

      auto load = builder.create<mlir::LoadOp>(location, loop_var_memref)
                      .result();  //, zero_index);

      if (load.getType().getIntOrFloatBitWidth() <
          b_val.getType().getIntOrFloatBitWidth()) {
        load = builder.create<mlir::ZeroExtendIOp>(location, load,
                                                   b_val.getType());
      } else if (b_val.getType().getIntOrFloatBitWidth() <
                 load.getType().getIntOrFloatBitWidth()) {
        b_val = builder.create<mlir::ZeroExtendIOp>(location, b_val,
                                                    load.getType());
      }

      auto cmp = builder.create<mlir::CmpIOp>(
          location, mlir::CmpIPredicate::slt, load, b_val);
      builder.create<mlir::CondBranchOp>(location, cmp, bodyBlock, exitBlock);

      builder.setInsertionPointToStart(bodyBlock);
      // body needs to load the loop variable
      auto x = builder.create<mlir::LoadOp>(location,
                                            product_memref);  //, zero_index);

      auto mult = builder.create<mlir::MulIOp>(location, x, lhs);
      builder.create<mlir::StoreOp>(location, mult, product_memref);  //,
      // llvm::makeArrayRef(std::vector<mlir::Value>{tmp2}));

      builder.create<mlir::BranchOp>(location, incBlock);

      builder.setInsertionPointToStart(incBlock);
      auto load_inc = builder.create<mlir::LoadOp>(
          location, loop_var_memref);  //, zero_index);
      auto add = builder.create<mlir::AddIOp>(location, load_inc, c_val);

      builder.create<mlir::StoreOp>(location, add, loop_var_memref);  //,
      // llvm::makeArrayRef(std::vector<mlir::Value>{tmp2}));

      builder.create<mlir::BranchOp>(location, headerBlock);

      builder.setInsertionPointToStart(exitBlock);

      symbol_table.set_last_created_block(exitBlock);
      builder.create<mlir::StoreOp>(
          location,
          builder.create<mlir::ConstantOp>(
              location, mlir::IntegerAttr::get(lhs_element_type, 1)),
          product_memref);

      current_value = builder.create<mlir::LoadOp>(
          location, product_memref);  //, zero_index);
      // Lower bound = 0
      mlir::Value lbs_val = get_or_create_constant_index_value(
          0, location, 64, symbol_table, builder);
      mlir::Value step_val = get_or_create_constant_index_value(
          1, location, 64, symbol_table, builder);
      builder.create<mlir::scf::ForOp>(
          location, lbs_val, ubs_val, step_val, llvm::None,
          [&](mlir::OpBuilder &nestedBuilder, mlir::Location nestedLoc,
              mlir::Value iv, mlir::ValueRange itrArgs) {
            mlir::OpBuilder::InsertionGuard guard(nestedBuilder);
            // Load - Multiply - Store
            auto x =
                nestedBuilder.create<mlir::LoadOp>(nestedLoc, product_memref);
            auto mult = nestedBuilder.create<mlir::MulIOp>(nestedLoc, x, lhs);
            nestedBuilder.create<mlir::StoreOp>(nestedLoc, mult,
                                                product_memref);
            nestedBuilder.create<mlir::scf::YieldOp>(nestedLoc);
          });
      current_value = builder.create<mlir::LoadOp>(location, product_memref);
      return 0;
    } else if (lhs_element_type.isa<mlir::FloatType>()) {
      if (!rhs_element_type.isa<mlir::FloatType>()) {