Commit 50595988 authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

Fixed Index type for StoreOp



Making sure that we use the IndexType (not integer type) for MemRef

Signed-off-by: default avatarThien Nguyen <nguyentm@ornl.gov>
parent 41b5b35b
Loading
Loading
Loading
Loading
+8 −0
Original line number Diff line number Diff line
@@ -218,10 +218,18 @@ class qasm3_visitor : public qasm3::qasm3BaseVisitor {
          "Cannot allocate and initialize memory, shape and number of initial "
          "value indices is incorrect");
    }

    // Assert that the values to init the memref array
    // must be of the expected type.
    for (const auto &init_val : initial_values) {
      assert(init_val.getType() == type);
    }

    // Allocate
    auto allocation = allocate_1d_memory(location, shape, type);
    // and initialize
    for (int i = 0; i < initial_values.size(); i++) {
      assert(initial_indices[i].getType().isa<mlir::IndexType>());
      builder.create<mlir::StoreOp>(location, initial_values[i], allocation,
                                    initial_indices[i]);
    }
+2 −1
Original line number Diff line number Diff line
@@ -1057,6 +1057,7 @@ antlrcpp::Any qasm3_expression_generator::visitExpressionTerminator(
              auto add =
                  builder.create<mlir::AddIOp>(location, load_inc, c_val);
              
              assert(tmp2.getType().isa<mlir::IndexType>());
              builder.create<mlir::StoreOp>(
                  location, add, loop_var_memref,
                  llvm::makeArrayRef(std::vector<mlir::Value>{tmp2}));
+3 −2
Original line number Diff line number Diff line
@@ -558,8 +558,8 @@ antlrcpp::Any qasm3_visitor::visitClassicalAssignment(
        }

        for (int i = 0; i < lhs_shape; i++) {
          mlir::Value pos = get_or_create_constant_integer_value(
              i, location, builder.getIntegerType(64), symbol_table, builder);
          mlir::Value pos = get_or_create_constant_index_value(
              i, location, 64, symbol_table, builder);
          auto load = builder.create<mlir::LoadOp>(location, rhs, pos);
          builder.create<mlir::StoreOp>(
              location, load, lhs,
@@ -571,6 +571,7 @@ antlrcpp::Any qasm3_visitor::visitClassicalAssignment(
                            {lhs, rhs});
        }

        assert(v.getType().isa<mlir::IndexType>());
        builder.create<mlir::StoreOp>(
            location, rhs, lhs,
            llvm::makeArrayRef(std::vector<mlir::Value>{v}));
+7 −5
Original line number Diff line number Diff line
@@ -34,8 +34,8 @@ antlrcpp::Any qasm3_visitor::visitLoopStatement(
        exp_generator.visit(exp);
        auto value = exp_generator.current_value;

        mlir::Value pos = get_or_create_constant_integer_value(
            counter, location, builder.getI64Type(), symbol_table, builder);
        mlir::Value pos = get_or_create_constant_index_value(
            counter, location, 64, symbol_table, builder);

        builder.create<mlir::StoreOp>(
            location, value, allocation,
@@ -51,9 +51,10 @@ antlrcpp::Any qasm3_visitor::visitLoopStatement(
      auto tmp2 = get_or_create_constant_index_value(0, location, 64,
                                                     symbol_table, builder);
      llvm::ArrayRef<mlir::Value> zero_index(tmp2);

      // Loop var must also be an Index type
      // since we'll store the loop index values to this variable.
      auto loop_var_memref = allocate_1d_memory_and_initialize(
          location, 1, builder.getI64Type(), std::vector<mlir::Value>{tmp},
          location, 1, builder.getIndexType(), std::vector<mlir::Value>{tmp},
          llvm::makeArrayRef(std::vector<mlir::Value>{tmp}));

      auto b_val = get_or_create_constant_index_value(n_expr, location, 64,
@@ -115,6 +116,7 @@ antlrcpp::Any qasm3_visitor::visitLoopStatement(
          builder.create<mlir::LoadOp>(location, loop_var_memref, zero_index);
      auto add = builder.create<mlir::AddIOp>(location, load_inc, c_val);
      
      assert(tmp2.getType().isa<mlir::IndexType>());
      builder.create<mlir::StoreOp>(
          location, add, loop_var_memref,
          llvm::makeArrayRef(std::vector<mlir::Value>{tmp2}));
+1 −0
Original line number Diff line number Diff line
@@ -217,6 +217,7 @@ antlrcpp::Any qasm3_visitor::visitQuantumMeasurementAssignment(
          0, location, 64, symbol_table, builder);
      }

      assert(v.getType().isa<mlir::IndexType>());
      builder.create<mlir::StoreOp>(
          location, instop.bit(), bit_value,
          llvm::makeArrayRef(std::vector<mlir::Value>{v}));