Commit f354f4f2 authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

Move QCOR_EXPECT_TRUE handler to a cpp file



Rewrite using an if block.

Signed-off-by: default avatarThien Nguyen <nguyentm@ornl.gov>
parent 3fef48ef
Loading
Loading
Loading
Loading
+2 −62
Original line number Diff line number Diff line
@@ -114,64 +114,9 @@ class qasm3_visitor : public qasm3::qasm3BaseVisitor {
  // Visit the compute-action-uncompute expression
  antlrcpp::Any visitCompute_action_stmt(qasm3Parser::Compute_action_stmtContext *context) override;

  // QCOR_EXPECT_TRUE handler
  antlrcpp::Any visitQcor_test_statement(
      qasm3Parser::Qcor_test_statementContext* context) override {
    auto location = get_location(builder, file_name, context);

    auto boolean_expr = context->booleanExpression();
    qasm3_expression_generator exp_generator(builder, symbol_table, file_name);
    exp_generator.visit(boolean_expr);
    auto expr_value = exp_generator.current_value;

    // So we have a conditional result, want
    // to negate it and see if == true
    expr_value = builder.create<mlir::CmpIOp>(
        location, mlir::CmpIPredicate::ne, expr_value,
        get_or_create_constant_integer_value(1, location, builder.getI1Type(),
                                             symbol_table, builder));

    auto currRegion = builder.getBlock()->getParent();
    auto savept = builder.saveInsertionPoint();
    auto thenBlock = builder.createBlock(currRegion, currRegion->end());
    mlir::Block* exitBlock = builder.createBlock(currRegion, currRegion->end());

    // Build up the THEN Block
    builder.setInsertionPointToStart(thenBlock);

    auto sl = "QCOR Test Failure: " + context->getText() + "\n";
    llvm::StringRef string_type_name("StringType");
    mlir::Identifier dialect =
        mlir::Identifier::get("quantum", builder.getContext());
    auto str_type =
        mlir::OpaqueType::get(builder.getContext(), dialect, string_type_name);
    auto str_attr = builder.getStringAttr(sl);

    std::hash<std::string> hasher;
    auto hash = hasher(sl);
    std::stringstream ss;
    ss << "__internal_string_literal__" << hash;
    std::string var_name = ss.str();
    auto var_name_attr = builder.getStringAttr(var_name);

    auto string_literal = builder.create<mlir::quantum::CreateStringLiteralOp>(
        location, str_type, str_attr, var_name_attr);
    builder.create<mlir::quantum::PrintOp>(
        location, llvm::makeArrayRef(std::vector<mlir::Value>{string_literal}));

    auto integer_attr = mlir::IntegerAttr::get(builder.getI32Type(), 1);
    auto ret = builder.create<mlir::ConstantOp>(location, integer_attr);
    builder.create<mlir::ReturnOp>(location, llvm::ArrayRef<mlir::Value>(ret));

    // Restore the insertion point and create the conditional statement
    builder.restoreInsertionPoint(savept);
    builder.create<mlir::CondBranchOp>(location, expr_value, thenBlock,
                                       exitBlock);
    builder.setInsertionPointToStart(exitBlock);

    symbol_table.set_last_created_block(exitBlock);

    return 0;
  }
      qasm3Parser::Qcor_test_statementContext *context) override;

  antlrcpp::Any visitPragma(qasm3Parser::PragmaContext *ctx) override {
    // Handle the #pragma { export; } directive
@@ -192,11 +137,6 @@ class qasm3_visitor : public qasm3::qasm3BaseVisitor {
  mlir::ModuleOp m_module;
  std::string file_name = "";
  bool enable_nisq_ifelse = false;  
  // We keep reference to these blocks so that
  // we can handle break/continue correctly
  mlir::Block* current_loop_exit_block;
  mlir::Block* current_loop_header_block;
  mlir::Block* current_loop_incrementor_block;

  // The symbol table, keeps track of current scope
  ScopedSymbolTable symbol_table;
+7 −18
Original line number Diff line number Diff line
@@ -70,34 +70,23 @@ affineLoopBuilder(mlir::Value lbs_val, mlir::Value ubs_val, int64_t step,
namespace qcor {
antlrcpp::Any
qasm3_visitor::visitLoopStatement(qasm3Parser::LoopStatementContext *context) {
  auto program_block = context->programBlock();
  const std::string program_block_str = program_block->getText();
  if (program_block_str.find("QCOR_EXPECT_TRUE") != std::string::npos) {
    // QCOR_EXPECT_TRUE will involve early escape (return)
    // hence is not compatible with Region-based dataflow style (Affine/SCF)
    // Since this QCOR_EXPECT_TRUE is for testing only.
    // Don't support it for now.
    // Will need to figure out how to make it work with Affine/SCF.
    printErrorMessage(
        "QCOR_EXPECT_TRUE in loop is not supported now. Stay tuned.", context);
  }
  auto loop_signature = context->loopSignature();
  if (auto membership_test = loop_signature->membershipTest()) {
  if (auto membership_test = context->loopSignature()->membershipTest()) {
    // this is a for loop
    auto set_declaration = membership_test->setDeclaration();
    if (set_declaration->LBRACE()) {
      // Set-based for loop:
      // e.g., for i in {1,3,5,6}
      createSetBasedForLoop(context);
    } else if (auto range = set_declaration->rangeDefinition()) {
      // this is a range definition
      //     rangeDefinition
      // : LBRACKET expression? COLON expression? ( COLON expression )? RBRACKET
      // ;
    } else if (set_declaration->rangeDefinition()) {
      // Range-based for loop
      // e.g., for i in [0:10]
      createRangeBasedForLoop(context);
    } else {
      printErrorMessage(
          "For loops must be of form 'for i in {SET}' or 'for i in [RANGE]'.");
    }
  } else {
    // While loop:
    createWhileLoop(context);
  }

+46 −0
Original line number Diff line number Diff line
#include "qasm3_visitor.hpp"
#include "mlir/Dialect/SCF/SCF.h"

namespace qcor {
antlrcpp::Any qasm3_visitor::visitQcor_test_statement(
    qasm3Parser::Qcor_test_statementContext *context) {
  auto location = get_location(builder, file_name, context);
  auto boolean_expr = context->booleanExpression();
  qasm3_expression_generator exp_generator(builder, symbol_table, file_name);
  exp_generator.visit(boolean_expr);
  auto expr_value = exp_generator.current_value;
  // So we have a conditional result, want
  // to negate it and see if == true
  expr_value = builder.create<mlir::CmpIOp>(
      location, mlir::CmpIPredicate::ne, expr_value,
      get_or_create_constant_integer_value(1, location, builder.getI1Type(),
                                           symbol_table, builder));
  // False (not equal true is true): print message then return 1;
  auto scfIfOp = builder.create<mlir::scf::IfOp>(location, mlir::TypeRange(),
                                                 expr_value, false);
  auto thenBodyBuilder = scfIfOp.getThenBodyBuilder();
  auto sl = "QCOR Test Failure: " + context->getText() + "\n";
  llvm::StringRef string_type_name("StringType");
  mlir::Identifier dialect =
      mlir::Identifier::get("quantum", thenBodyBuilder.getContext());
  auto str_type = mlir::OpaqueType::get(thenBodyBuilder.getContext(), dialect,
                                        string_type_name);
  auto str_attr = thenBodyBuilder.getStringAttr(sl);
  std::hash<std::string> hasher;
  auto hash = hasher(sl);
  std::stringstream ss;
  ss << "__internal_string_literal__" << hash;
  std::string var_name = ss.str();
  auto var_name_attr = thenBodyBuilder.getStringAttr(var_name);
  auto string_literal =
      thenBodyBuilder.create<mlir::quantum::CreateStringLiteralOp>(
          location, str_type, str_attr, var_name_attr);
  thenBodyBuilder.create<mlir::quantum::PrintOp>(
      location, llvm::makeArrayRef(std::vector<mlir::Value>{string_literal}));
  auto integer_attr = mlir::IntegerAttr::get(thenBodyBuilder.getI32Type(), 1);
  auto ret = builder.create<mlir::ConstantOp>(location, integer_attr);
  thenBodyBuilder.create<mlir::ReturnOp>(location, llvm::ArrayRef<mlir::Value>(ret));

  return 0;
}
} // namespace qcor
 No newline at end of file