Commit 4d46b7ff authored by Mccaskey, Alex's avatar Mccaskey, Alex
Browse files

general cleanup, adding more documentation

parent 51692cc0
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -42,4 +42,4 @@ endif()

install(TARGETS ${LIBRARY_NAME} DESTINATION ${CMAKE_INSTALL_PREFIX}/lib)

#add_subdirectory(tests)
 No newline at end of file
add_subdirectory(tests)
 No newline at end of file
+42 −37
Original line number Diff line number Diff line
@@ -19,10 +19,17 @@ using namespace qasm3;

namespace qcor {

// This class provides a set of visitor methods for the
// various nodes of the auto-generated qasm3.g4 Antlr parse tree.
// It keeps track of the translation unit symbol table and the MLIR
// OpBuilder and its goal is to build up an MLIR representation of
// the qasm3 source code using the QuantumDialect and the StdDialect.
class qasm3_visitor : public qasm3::qasm3BaseVisitor {
 public:
  // Return the symbol table.
  ScopedSymbolTable& getScopedSymbolTable() { return symbol_table; }

  // The constructor, instantiates commonly used opaque types
  qasm3_visitor(mlir::OpBuilder b, mlir::ModuleOp m, std::string& fname)
      : builder(b), file_name(fname), m_module(m) {
    auto context = b.getContext();
@@ -34,12 +41,15 @@ class qasm3_visitor : public qasm3::qasm3BaseVisitor {
    result_type = mlir::IntegerType::get(context, 1);
  }

  // see visitor_handlers/quantum_types_handler.cpp
  // Visit nodes corresponding to quantum variable and gate declarations.
  // see visitor_handlers/quantum_types_handler.cpp for implementation
  antlrcpp::Any visitQuantumGateDefinition(
      qasm3Parser::QuantumGateDefinitionContext* context) override;
  antlrcpp::Any visitQuantumDeclaration(
      qasm3Parser::QuantumDeclarationContext* context) override;

  // Visit nodes corresponding to quantum gate, subroutine, and 
  // kernel calls. 
  // see visitor_handlers/quantum_instruction_handler.cpp
  antlrcpp::Any visitQuantumGateCall(
      qasm3Parser::QuantumGateCallContext* context) override;
@@ -48,28 +58,37 @@ class qasm3_visitor : public qasm3::qasm3BaseVisitor {
  antlrcpp::Any visitKernelCall(
      qasm3Parser::KernelCallContext* context) override;

  // Visit nodes corresponding to quantum measurement and 
  // measurement assignment 
  // see visitor_handlers/measurement_handler.cpp
  antlrcpp::Any visitQuantumMeasurement(
      qasm3Parser::QuantumMeasurementContext* context) override;
  antlrcpp::Any visitQuantumMeasurementAssignment(
      qasm3Parser::QuantumMeasurementAssignmentContext* context) override;

  // // see visitor_handlers/subroutine_handler.cpp
  // Visit nodes corresponding to subroutine definitions 
  // and corresponding return statements
  // see visitor_handlers/subroutine_handler.cpp
  antlrcpp::Any visitSubroutineDefinition(
      qasm3Parser::SubroutineDefinitionContext* context) override;
  antlrcpp::Any visitReturnStatement(
      qasm3Parser::ReturnStatementContext* context) override;

  // Visit nodes corresponding to if/else branching statements
  // see visitor_handlers/conditional_handler.cpp
  antlrcpp::Any visitBranchingStatement(
      qasm3Parser::BranchingStatementContext* context) override;

  // Visit nodes corresponding to for and while loop statements
  // see visitor_handlers/for_stmt_handler.cpp
  antlrcpp::Any visitLoopStatement(
      qasm3Parser::LoopStatementContext* context) override;
  antlrcpp::Any visitControlDirective(
      qasm3Parser::ControlDirectiveContext* context) override;

  // Visit nodes related to classical variable declrations - 
  // constants, int, float, bit, etc and then assignments to 
  // those variables. 
  // see visitor_handlers/classical_types_handler.cpp
  antlrcpp::Any visitConstantDeclaration(
      qasm3Parser::ConstantDeclarationContext* context) override;
@@ -82,25 +101,6 @@ class qasm3_visitor : public qasm3::qasm3BaseVisitor {
  antlrcpp::Any visitClassicalAssignment(
      qasm3Parser::ClassicalAssignmentContext* context) override;

  // antlrcpp::Any visitExpression(
  //     qasm3Parser::ExpressionContext* context) override {
  //   if (context->incrementor()) {
  //     qasm3_expression_generator exp_generator(builder, symbol_table,
  //                                              file_name);
  //     exp_generator.visit(context);
  //     auto expr_value = exp_generator.current_value;
  //     return 0;
  //   }
  //   return visitChildren(context);
  // }
  // --------//

  // The last block added by either loop or if stmts
  mlir::Block* current_block;

  mlir::Block* current_loop_exit_block;
  mlir::Block* current_loop_header_block;
  mlir::Block* current_loop_incrementor_block;
 protected:
  // Reference to the MLIR OpBuilder and ModuleOp
  // this MLIRGen task
@@ -108,30 +108,34 @@ class qasm3_visitor : public qasm3::qasm3BaseVisitor {
  mlir::ModuleOp m_module;
  std::string file_name = "";

  std::size_t current_scope = 0;
  // We keep reference to these blocks so that 
  // we can handle break/continue correctly
  mlir::Block* current_loop_exit_block;
  mlir::Block* current_loop_header_block;
  mlir::Block* current_loop_incrementor_block;

  // The symbol table, keeps track of current scope
  ScopedSymbolTable symbol_table;

  bool at_global_scope = true;
  // Booleans used for indicating how to construct
  // return statement for subroutines
  bool subroutine_return_statment_added = false;
  bool is_return_stmt = false;

  // Keep track of expected subroutine return type
  mlir::Type current_function_return_type;

  // Reference to MLIR Quantum Opaque Types
  mlir::Type qubit_type;
  mlir::Type array_type;
  mlir::Type result_type;

  void createInstOps_HandleBroadcast(std::string name, std::vector<mlir::Value> qbit_values,
                                     std::vector<mlir::Value> param_values,mlir::Location location, antlr4::ParserRuleContext* context);

  void update_symbol_table(const std::string& key, mlir::Value value,
                           std::vector<std::string> variable_attributes = {},
                           bool overwrite = false) {
    symbol_table.add_symbol(key, value, variable_attributes, overwrite);
    return;
  }
  // This method will add correct number of InstOps
  // based on quantum gate broadcasting
  void createInstOps_HandleBroadcast(std::string name,
                                     std::vector<mlir::Value> qbit_values,
                                     std::vector<mlir::Value> param_values,
                                     mlir::Location location,
                                     antlr4::ParserRuleContext* context);

  // This function serves as a utility for creating a MemRef and
  // corresponding AllocOp of a given 1d shape. It will also store
@@ -145,10 +149,9 @@ class qasm3_visitor : public qasm3::qasm3BaseVisitor {
          "Cannot allocate and initialize memory, shape and number of initial "
          "value indices is incorrect");
    }
    llvm::ArrayRef<int64_t> shaperef{shape};

    auto mem_type = mlir::MemRefType::get(shaperef, type);
    mlir::Value allocation = builder.create<mlir::AllocaOp>(location, mem_type);
    // Allocate
    auto allocation = allocate_1d_memory(location, shape, type);
    // and initialize
    for (int i = 0; i < initial_values.size(); i++) {
      builder.create<mlir::StoreOp>(location, initial_values[i], allocation,
                                    initial_indices[i]);
@@ -156,6 +159,8 @@ class qasm3_visitor : public qasm3::qasm3BaseVisitor {
    return allocation;
  }

  // This function serves as a utility for creating a MemRef and
  // corresponding AllocOp of a given 1d shape.
  mlir::Value allocate_1d_memory(mlir::Location location, int64_t shape,
                                 mlir::Type type) {
    llvm::ArrayRef<int64_t> shaperef{shape};
+0 −3
Original line number Diff line number Diff line
@@ -22,9 +22,6 @@ class qasm3_expression_generator : public qasm3::qasm3BaseVisitor {
  bool found_negation_unary_op = false;
  mlir::Value indexed_variable_value;

  // The last block added by either loop or if stmts
  mlir::Block* current_block;

  mlir::Type internal_value_type;


+11 −0
Original line number Diff line number Diff line
@@ -21,6 +21,17 @@ void printErrorMessage(const std::string msg,
            << "   " << msg << "\n\n";
  if (do_exit) exit(1);
}
void printErrorMessage(const std::string msg, antlr4::ParserRuleContext* context, std::vector<mlir::Value>&& v, bool do_exit) {
  auto line = context->getStart()->getLine();
  auto col = context->getStart()->getCharPositionInLine();
  std::cout << "\n[OPENQASM3 MLIRGen] Error at " << line << ":" << col << "\n"
            << "   AntlrText: " << context->getText() << "\n"
            << "   " << msg << "\n\n";
  std::cout << "MLIR Values:\n";
  for (auto vv : v) vv.dump();
  
  if (do_exit) exit(1);
}

void printErrorMessage(const std::string msg, mlir::Value v) {
  printErrorMessage(msg, false);
+1 −1
Original line number Diff line number Diff line
@@ -26,7 +26,7 @@ void split(const std::string& s, char delim, Op op) {

void printErrorMessage(const std::string msg, bool do_exit = true);
void printErrorMessage(const std::string msg, antlr4::ParserRuleContext* context, bool do_exit = true);

void printErrorMessage(const std::string msg, antlr4::ParserRuleContext* context, std::vector<mlir::Value>&& v, bool do_exit = true);
void printErrorMessage(const std::string msg, mlir::Value v);
void printErrorMessage(const std::string msg, std::vector<mlir::Value>&& v);

Loading