Commit 0e931e10 authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

Fixed type casting in fixed-width loop to work in both MLIR and LLVM reps



This fixed the type casting unit test to run w/ MLIR opt passes.

Signed-off-by: default avatarThien Nguyen <nguyentm@ornl.gov>
parent 59afcfbc
Loading
Loading
Loading
Loading
+44 −19
Original line number Diff line number Diff line
@@ -34,7 +34,13 @@ antlrcpp::Any qasm3_expression_generator::visitTerminal(
    // std::cout << "TERMNODE:\n";
    indexed_variable_value = current_value;
    if (casting_indexed_integer_to_bool) {
      if (indexed_variable_value.getType().isa<mlir::MemRefType>()) {
        internal_value_type = indexed_variable_value.getType()
                                  .cast<mlir::MemRefType>()
                                  .getElementType();
      } else {
        internal_value_type = builder.getIndexType();
      }
    } else if (indexed_variable_value.getType().isa<mlir::MemRefType>()) {
      internal_value_type = builder.getIndexType();
    }
@@ -63,28 +69,40 @@ antlrcpp::Any qasm3_expression_generator::visitTerminal(
      // builder.create<mlir::SubIOp>(location, current_value,
      // get_or_create_constant_integer_value(1, location));
      auto bw = indexed_variable_value.getType().getIntOrFloatBitWidth();
      auto casted_idx =
          builder.create<mlir::IndexCastOp>(location, current_value,
                                            indexed_variable_value.getType()
      
      // NOTE: UnsignedShiftRightOp (std dialect) expects operands of type "signless-integer-like"
      // i.e. although it treats the operants as unsigned, they must be of type signless (int not uint). 
      // https://mlir.llvm.org/docs/Dialects/Standard/#stdshift_right_unsigned-mlirunsignedshiftrightop

      // This is the type for all variables involved in this procedure:
      mlir::IntegerType signless_integer_like_type =
          builder.getIntegerType(indexed_variable_value.getType()
                                     .cast<mlir::MemRefType>()
                                                .getElementType());
      auto load_value = builder.create<mlir::LoadOp>(
          location, indexed_variable_value,
          get_or_create_constant_index_value(0, location, 64, symbol_table,
                                             builder));
                                     .getElementType()
                                     .getIntOrFloatBitWidth());
      auto casted_idx =
          builder
              .create<mlir::quantum::IntegerCastOp>(
                  location, signless_integer_like_type, current_value)
              .output();
      auto load_value =
          builder.create<mlir::LoadOp>(location, indexed_variable_value);

      auto load_value_casted =
          builder
              .create<mlir::quantum::IntegerCastOp>(
                  location, signless_integer_like_type, load_value)
              .output();
      // Note: 'std.shift_right_unsigned' op requires the same type for
      // all operands and results
      assert(load_value_casted.getType() == casted_idx.getType());
      auto shift = builder.create<mlir::UnsignedShiftRightOp>(
          location, load_value, casted_idx);
          location, load_value_casted, casted_idx);

      auto old_int_type = internal_value_type;
      internal_value_type = indexed_variable_value.getType();
      auto and_value = builder.create<mlir::AndOp>(
          location, shift,
          get_or_create_constant_integer_value(1, location,
                                               indexed_variable_value.getType()
                                                   .cast<mlir::MemRefType>()
                                                   .getElementType(),
                                               symbol_table, builder));
      internal_value_type = old_int_type;
          get_or_create_constant_integer_value(
              1, location, signless_integer_like_type, symbol_table, builder));
      update_current_value(and_value.result());
      casting_indexed_integer_to_bool = false;
    } else {
@@ -1022,8 +1040,15 @@ antlrcpp::Any qasm3_expression_generator::visitExpressionTerminator(
                                                     zero_index)
                               .result();

              mlir::Value j_val_as_index = builder.create<mlir::IndexCastOp>(
                                location, j_val, builder.getIndexType());

              auto load_bit_j =
                  builder.create<mlir::LoadOp>(location, var_to_cast, j_val);
                  j_val.getType().isa<mlir::IndexType>()
                      ? builder.create<mlir::LoadOp>(location, var_to_cast,
                                                     j_val)
                      : builder.create<mlir::LoadOp>(location, var_to_cast,
                                                     j_val_as_index);
              // Extend i1 to the same width as i
              auto load_j_ext = builder.create<mlir::ZeroExtendIOp>(
                  location, load_bit_j, int_value_type);
+23 −6
Original line number Diff line number Diff line
@@ -93,6 +93,22 @@ mlir::Value get_or_create_constant_integer_value(
  auto width = type.getIntOrFloatBitWidth();
  if (symbol_table.has_constant_integer(idx, width)) {
    return symbol_table.get_constant_integer(idx, width);
  } else {
    // Handle unsigned int constant:
    // ConstantOp (std dialect) doesn't support Signed type (uint)
    if (!type.cast<mlir::IntegerType>().isSignless()) {
      auto signless_int_type =
          builder.getIntegerType(type.getIntOrFloatBitWidth());
      auto integer_attr = mlir::IntegerAttr::get(signless_int_type, idx);

      auto ret =
          builder
              .create<mlir::quantum::IntegerCastOp>(
                  location, type,
                  builder.create<mlir::ConstantOp>(location, integer_attr))
              .output();
      symbol_table.add_constant_integer(idx, ret, width);
      return ret;
    } else {
      auto integer_attr = mlir::IntegerAttr::get(type, idx);
      assert(integer_attr.getType().cast<mlir::IntegerType>().isSignless());
@@ -101,6 +117,7 @@ mlir::Value get_or_create_constant_integer_value(
      return ret;
    }
  }
}

mlir::Value get_or_create_constant_index_value(const std::size_t idx,
                                               mlir::Location location,