Commit 4d46b7ff authored by Mccaskey, Alex's avatar Mccaskey, Alex
Browse files

general cleanup, adding more documentation


Signed-off-by: Mccaskey, Alex's avatarAlex McCaskey <mccaskeyaj@ornl.gov>
parent 51692cc0
......@@ -42,4 +42,4 @@ endif()
install(TARGETS ${LIBRARY_NAME} DESTINATION ${CMAKE_INSTALL_PREFIX}/lib)
#add_subdirectory(tests)
\ No newline at end of file
add_subdirectory(tests)
\ No newline at end of file
......@@ -19,10 +19,17 @@ using namespace qasm3;
namespace qcor {
// This class provides a set of visitor methods for the
// various nodes of the auto-generated qasm3.g4 Antlr parse tree.
// It keeps track of the translation unit symbol table and the MLIR
// OpBuilder and its goal is to build up an MLIR representation of
// the qasm3 source code using the QuantumDialect and the StdDialect.
class qasm3_visitor : public qasm3::qasm3BaseVisitor {
public:
// Return the symbol table.
ScopedSymbolTable& getScopedSymbolTable() { return symbol_table; }
// The constructor, instantiates commonly used opaque types
qasm3_visitor(mlir::OpBuilder b, mlir::ModuleOp m, std::string& fname)
: builder(b), file_name(fname), m_module(m) {
auto context = b.getContext();
......@@ -34,12 +41,15 @@ class qasm3_visitor : public qasm3::qasm3BaseVisitor {
result_type = mlir::IntegerType::get(context, 1);
}
// see visitor_handlers/quantum_types_handler.cpp
// Visit nodes corresponding to quantum variable and gate declarations.
// see visitor_handlers/quantum_types_handler.cpp for implementation
antlrcpp::Any visitQuantumGateDefinition(
qasm3Parser::QuantumGateDefinitionContext* context) override;
antlrcpp::Any visitQuantumDeclaration(
qasm3Parser::QuantumDeclarationContext* context) override;
// Visit nodes corresponding to quantum gate, subroutine, and
// kernel calls.
// see visitor_handlers/quantum_instruction_handler.cpp
antlrcpp::Any visitQuantumGateCall(
qasm3Parser::QuantumGateCallContext* context) override;
......@@ -48,28 +58,37 @@ class qasm3_visitor : public qasm3::qasm3BaseVisitor {
antlrcpp::Any visitKernelCall(
qasm3Parser::KernelCallContext* context) override;
// Visit nodes corresponding to quantum measurement and
// measurement assignment
// see visitor_handlers/measurement_handler.cpp
antlrcpp::Any visitQuantumMeasurement(
qasm3Parser::QuantumMeasurementContext* context) override;
antlrcpp::Any visitQuantumMeasurementAssignment(
qasm3Parser::QuantumMeasurementAssignmentContext* context) override;
// // see visitor_handlers/subroutine_handler.cpp
// Visit nodes corresponding to subroutine definitions
// and corresponding return statements
// see visitor_handlers/subroutine_handler.cpp
antlrcpp::Any visitSubroutineDefinition(
qasm3Parser::SubroutineDefinitionContext* context) override;
antlrcpp::Any visitReturnStatement(
qasm3Parser::ReturnStatementContext* context) override;
// Visit nodes corresponding to if/else branching statements
// see visitor_handlers/conditional_handler.cpp
antlrcpp::Any visitBranchingStatement(
qasm3Parser::BranchingStatementContext* context) override;
// Visit nodes corresponding to for and while loop statements
// see visitor_handlers/for_stmt_handler.cpp
antlrcpp::Any visitLoopStatement(
qasm3Parser::LoopStatementContext* context) override;
antlrcpp::Any visitControlDirective(
qasm3Parser::ControlDirectiveContext* context) override;
// Visit nodes related to classical variable declrations -
// constants, int, float, bit, etc and then assignments to
// those variables.
// see visitor_handlers/classical_types_handler.cpp
antlrcpp::Any visitConstantDeclaration(
qasm3Parser::ConstantDeclarationContext* context) override;
......@@ -82,25 +101,6 @@ class qasm3_visitor : public qasm3::qasm3BaseVisitor {
antlrcpp::Any visitClassicalAssignment(
qasm3Parser::ClassicalAssignmentContext* context) override;
// antlrcpp::Any visitExpression(
// qasm3Parser::ExpressionContext* context) override {
// if (context->incrementor()) {
// qasm3_expression_generator exp_generator(builder, symbol_table,
// file_name);
// exp_generator.visit(context);
// auto expr_value = exp_generator.current_value;
// return 0;
// }
// return visitChildren(context);
// }
// --------//
// The last block added by either loop or if stmts
mlir::Block* current_block;
mlir::Block* current_loop_exit_block;
mlir::Block* current_loop_header_block;
mlir::Block* current_loop_incrementor_block;
protected:
// Reference to the MLIR OpBuilder and ModuleOp
// this MLIRGen task
......@@ -108,30 +108,34 @@ class qasm3_visitor : public qasm3::qasm3BaseVisitor {
mlir::ModuleOp m_module;
std::string file_name = "";
std::size_t current_scope = 0;
// We keep reference to these blocks so that
// we can handle break/continue correctly
mlir::Block* current_loop_exit_block;
mlir::Block* current_loop_header_block;
mlir::Block* current_loop_incrementor_block;
// The symbol table, keeps track of current scope
ScopedSymbolTable symbol_table;
bool at_global_scope = true;
// Booleans used for indicating how to construct
// return statement for subroutines
bool subroutine_return_statment_added = false;
bool is_return_stmt = false;
// Keep track of expected subroutine return type
mlir::Type current_function_return_type;
// Reference to MLIR Quantum Opaque Types
mlir::Type qubit_type;
mlir::Type array_type;
mlir::Type result_type;
void createInstOps_HandleBroadcast(std::string name, std::vector<mlir::Value> qbit_values,
std::vector<mlir::Value> param_values,mlir::Location location, antlr4::ParserRuleContext* context);
void update_symbol_table(const std::string& key, mlir::Value value,
std::vector<std::string> variable_attributes = {},
bool overwrite = false) {
symbol_table.add_symbol(key, value, variable_attributes, overwrite);
return;
}
// This method will add correct number of InstOps
// based on quantum gate broadcasting
void createInstOps_HandleBroadcast(std::string name,
std::vector<mlir::Value> qbit_values,
std::vector<mlir::Value> param_values,
mlir::Location location,
antlr4::ParserRuleContext* context);
// This function serves as a utility for creating a MemRef and
// corresponding AllocOp of a given 1d shape. It will also store
......@@ -145,10 +149,9 @@ class qasm3_visitor : public qasm3::qasm3BaseVisitor {
"Cannot allocate and initialize memory, shape and number of initial "
"value indices is incorrect");
}
llvm::ArrayRef<int64_t> shaperef{shape};
auto mem_type = mlir::MemRefType::get(shaperef, type);
mlir::Value allocation = builder.create<mlir::AllocaOp>(location, mem_type);
// Allocate
auto allocation = allocate_1d_memory(location, shape, type);
// and initialize
for (int i = 0; i < initial_values.size(); i++) {
builder.create<mlir::StoreOp>(location, initial_values[i], allocation,
initial_indices[i]);
......@@ -156,6 +159,8 @@ class qasm3_visitor : public qasm3::qasm3BaseVisitor {
return allocation;
}
// This function serves as a utility for creating a MemRef and
// corresponding AllocOp of a given 1d shape.
mlir::Value allocate_1d_memory(mlir::Location location, int64_t shape,
mlir::Type type) {
llvm::ArrayRef<int64_t> shaperef{shape};
......
......@@ -21,9 +21,6 @@ class qasm3_expression_generator : public qasm3::qasm3BaseVisitor {
bool casting_indexed_integer_to_bool = false;
bool found_negation_unary_op = false;
mlir::Value indexed_variable_value;
// The last block added by either loop or if stmts
mlir::Block* current_block;
mlir::Type internal_value_type;
......
......@@ -21,6 +21,17 @@ void printErrorMessage(const std::string msg,
<< " " << msg << "\n\n";
if (do_exit) exit(1);
}
void printErrorMessage(const std::string msg, antlr4::ParserRuleContext* context, std::vector<mlir::Value>&& v, bool do_exit) {
auto line = context->getStart()->getLine();
auto col = context->getStart()->getCharPositionInLine();
std::cout << "\n[OPENQASM3 MLIRGen] Error at " << line << ":" << col << "\n"
<< " AntlrText: " << context->getText() << "\n"
<< " " << msg << "\n\n";
std::cout << "MLIR Values:\n";
for (auto vv : v) vv.dump();
if (do_exit) exit(1);
}
void printErrorMessage(const std::string msg, mlir::Value v) {
printErrorMessage(msg, false);
......
......@@ -26,7 +26,7 @@ void split(const std::string& s, char delim, Op op) {
void printErrorMessage(const std::string msg, bool do_exit = true);
void printErrorMessage(const std::string msg, antlr4::ParserRuleContext* context, bool do_exit = true);
void printErrorMessage(const std::string msg, antlr4::ParserRuleContext* context, std::vector<mlir::Value>&& v, bool do_exit = true);
void printErrorMessage(const std::string msg, mlir::Value v);
void printErrorMessage(const std::string msg, std::vector<mlir::Value>&& v);
......
......@@ -57,12 +57,12 @@ namespace qcor {
antlrcpp::Any qasm3_visitor::visitConstantDeclaration(
qasm3Parser::ConstantDeclarationContext* context) {
auto ass_list = context->equalsAssignmentList(); // :)
auto ass_list = context->equalsAssignmentList();
for (int i = 0; i < ass_list->Identifier().size(); i++) {
auto var_name = ass_list->Identifier(i)->getText();
if (var_name == "pi") {
printErrorMessage("pi is already defined in OPENQASM 3.");
printErrorMessage("pi is already defined in OPENQASM 3.", context);
}
auto equals_expr = ass_list->equalsExpression(i);
......@@ -70,7 +70,7 @@ antlrcpp::Any qasm3_visitor::visitConstantDeclaration(
exp_generator.visit(equals_expr);
auto expr_value = exp_generator.current_value;
update_symbol_table(var_name, expr_value, {"const"});
symbol_table.add_symbol(var_name, expr_value, {"const"});
}
return 0;
......@@ -81,22 +81,9 @@ antlrcpp::Any qasm3_visitor::visitSingleDesignatorDeclaration(
auto location = get_location(builder, file_name, context);
auto type = context->singleDesignatorType()->getText();
auto designator_expr = context->designator()->expression();
uint64_t width_idx;
{
// I only want the IDX in TYPE[IDX], don't need to add it the
// the current module... This will tell me the bit width
mlir::OpBuilder tmp_builder(builder.getContext());
ScopedSymbolTable tmp_table;
qasm3_expression_generator designator_exp_generator(
tmp_builder, tmp_table, file_name, builder.getIntegerType(64));
designator_exp_generator.visit(designator_expr);
auto designator_value = designator_exp_generator.current_value;
width_idx = designator_value.getDefiningOp<mlir::ConstantOp>()
.getValue()
.cast<mlir::IntegerAttr>()
.getUInt();
}
uint64_t width_idx = symbol_table.evaluate_constant_integer_expression(
designator_expr->getText());
mlir::Attribute init_attr;
mlir::Type value_type;
if (type == "int") {
......@@ -119,10 +106,12 @@ antlrcpp::Any qasm3_visitor::visitSingleDesignatorDeclaration(
init_attr = mlir::FloatAttr::get(value_type, 0.0);
} else {
printErrorMessage("we only support 16, 32, and 64 floating point types.");
printErrorMessage("we only support 16, 32, and 64 floating point types.",
context);
}
} else {
printErrorMessage("We do not currently support this type: " + type);
printErrorMessage("We do not currently support this type: " + type,
context);
}
// THis can now be either an identifierList or an equalsAssignementList
......@@ -216,6 +205,10 @@ antlrcpp::Any qasm3_visitor::visitNoDesignatorDeclaration(
// Save the allocation, the store op
symbol_table.add_symbol(variable, allocation);
}
} else {
printErrorMessage("We do not yet support this no designator type: " +
context->noDesignatorType()->getText(),
context);
}
return 0;
......@@ -241,6 +234,7 @@ antlrcpp::Any qasm3_visitor::visitBitDeclaration(
// ;
auto location = get_location(builder, file_name, context);
// First case is indexIdentifierList, no initialization
std::size_t size = 1;
if (auto index_ident_list = context->indexIdentifierList()) {
for (auto idx_identifier : index_ident_list->indexIdentifier()) {
......@@ -256,14 +250,15 @@ antlrcpp::Any qasm3_visitor::visitBitDeclaration(
init_values.push_back(get_or_create_constant_integer_value(
0, location, builder.getI1Type(), symbol_table, builder));
init_indices.push_back(get_or_create_constant_index_value(
0, location, 64, symbol_table, builder));
i, location, 64, symbol_table, builder));
}
auto allocation = allocate_1d_memory_and_initialize(
location, size, builder.getI1Type(), init_values,
llvm::makeArrayRef(init_indices));
update_symbol_table(var_name, allocation);
symbol_table.add_symbol(var_name, allocation);
}
} else {
// Second case is indexEqualsAssignmentList, so bits with initialization
auto index_equals_list = context->indexEqualsAssignmentList();
for (int i = 0; i < index_equals_list->indexIdentifier().size(); i++) {
......@@ -283,7 +278,8 @@ antlrcpp::Any qasm3_visitor::visitBitDeclaration(
if (size != equals_expr.length()) {
printErrorMessage(
"Invalid initial string assignment for bit array, sizes do not "
"match.");
"match.",
context);
}
std::vector<mlir::Value> initial_values, indices;
......@@ -298,7 +294,7 @@ antlrcpp::Any qasm3_visitor::visitBitDeclaration(
location, size, builder.getI1Type(), initial_values,
llvm::makeArrayRef(indices));
update_symbol_table(var_name, allocation);
symbol_table.add_symbol(var_name, allocation);
}
}
......@@ -320,33 +316,46 @@ antlrcpp::Any qasm3_visitor::visitClassicalAssignment(
auto var_name = context->indexIdentifier(0)->Identifier()->getText();
auto ass_op = context->assignmentOperator(); // :)
// Make sure this is a valid symbol
if (!symbol_table.has_symbol(var_name)) {
printErrorMessage("invalid variable name in classical assignement: " +
var_name + ", " + ass_op->getText());
var_name + ", " + ass_op->getText(),
context);
}
// Make sure rhs is an expression
if (!context->expression()) {
printErrorMessage(
"We only can handle classicalAssignment expressions at this time, "
"no "
"indexIdentifiers.");
"indexIdentifiers.",
context);
}
// If the lhs is a const variable, throw an error
if (!symbol_table.is_variable_mutable(var_name)) {
printErrorMessage("Cannot change variable " + var_name +
", it has been marked const.");
printErrorMessage(
"Cannot change variable " + var_name + ", it has been marked const.",
context);
}
// Get the LHS symbol
auto lhs = symbol_table.get_symbol(var_name);
auto width = lhs.getType()
.cast<mlir::MemRefType>()
.getElementType()
.getIntOrFloatBitWidth();
if (!lhs.getType().isa<mlir::MemRefType>()) {
printErrorMessage("LHS in classical assignment must be a MemRefType.",
context);
}
// auto width = lhs.getType()
// .cast<mlir::MemRefType>()
// .getElementType()
// .getIntOrFloatBitWidth();
// Get the RHS value
qasm3_expression_generator exp_generator(
builder, symbol_table, file_name,
lhs.getType().cast<mlir::MemRefType>().getElementType());
exp_generator.visit(context->expression());
auto rhs = exp_generator.current_value;
......@@ -367,108 +376,119 @@ antlrcpp::Any qasm3_visitor::visitClassicalAssignment(
builder.create<mlir::StoreOp>(
location, rhs, lhs, llvm::makeArrayRef(std::vector<mlir::Value>{pos}));
return 0;
}
} else {
if (!lhs.getType().isa<mlir::MemRefType>()) {
printErrorMessage("cannot assign a value to a lhs that is not a memreftype.");
}
// Create a 0 index value for our Load and Store Ops
llvm::ArrayRef<mlir::Value> zero_index(get_or_create_constant_index_value(
0, location, 64, symbol_table, builder));
// Get the lhs and rhs types
auto lhs_type = lhs.getType().cast<mlir::MemRefType>().getElementType();
mlir::Type rhs_type = rhs.getType();
mlir::Value load_result_rhs = rhs;
if (rhs_type.isa<mlir::MemRefType>()) {
// if rhs is a memref, let's load its 0th index value
rhs_type = rhs_type.cast<mlir::MemRefType>().getElementType();
auto load_rhs = builder.create<mlir::LoadOp>(location, rhs, zero_index);
load_result_rhs = load_rhs.result();
}
llvm::ArrayRef<mlir::Value> zero_index(get_or_create_constant_index_value(
0, location, 64, symbol_table, builder));
// Load the LHS value
auto load = builder.create<mlir::LoadOp>(location, lhs, zero_index);
auto load_result = load.result();
auto lhs_type = lhs.getType().cast<mlir::MemRefType>().getElementType();
mlir::Type rhs_type = rhs.getType();
mlir::Value load_result_rhs = rhs;
if (rhs_type.isa<mlir::MemRefType>()) {
rhs_type = rhs_type.cast<mlir::MemRefType>().getElementType();
auto load_rhs = builder.create<mlir::LoadOp>(location, rhs, zero_index);
load_result_rhs = load_rhs.result();
// Check what the assignment op is...
mlir::Value current_value;
auto assignment_op = ass_op->getText();
if (assignment_op == "+=") {
// If either are floats, use float addition
if (lhs_type.isa<mlir::FloatType>() || rhs_type.isa<mlir::FloatType>()) {
current_value =
builder.create<mlir::AddFOp>(location, load_result, load_result_rhs);
} else if (lhs_type.isa<mlir::IntegerType>() &&
rhs_type.isa<mlir::IntegerType>()) {
// Else both must be integers to perform integer addition
current_value =
builder.create<mlir::AddIOp>(location, load_result, load_result_rhs);
} else {
printErrorMessage("Could not perform += for values of these types.",
context, {lhs, rhs});
}
// Load the LHS, has to load an existing allocated value
// %2 = load i32, i32* %1
// Create the zero int value
auto load = builder.create<mlir::LoadOp>(location, lhs, zero_index);
auto load_result = load.result();
mlir::Value current_value;
auto assignment_op = ass_op->getText();
if (assignment_op == "+=") {
if (lhs_type.isa<mlir::FloatType>() || rhs_type.isa<mlir::FloatType>()) {
current_value = builder.create<mlir::AddFOp>(location, load_result,
load_result_rhs);
} else if (lhs_type.isa<mlir::IntegerType>() &&
rhs_type.isa<mlir::IntegerType>()) {
current_value = builder.create<mlir::AddIOp>(location, load_result,
load_result_rhs);
}
llvm::ArrayRef<mlir::Value> zero_index2(
get_or_create_constant_index_value(0, location, 64, symbol_table,
builder));
builder.create<mlir::StoreOp>(location, current_value, lhs, zero_index2);
} else if (assignment_op == "-=") {
if (lhs_type.isa<mlir::FloatType>() || rhs_type.isa<mlir::FloatType>()) {
current_value = builder.create<mlir::SubFOp>(location, load_result,
load_result_rhs);
} else if (lhs_type.isa<mlir::IntegerType>() &&
rhs_type.isa<mlir::IntegerType>()) {
current_value = builder.create<mlir::SubIOp>(location, load_result,
load_result_rhs);
}
// Store the added value to the lhs
llvm::ArrayRef<mlir::Value> zero_index2(get_or_create_constant_index_value(
0, location, 64, symbol_table, builder));
builder.create<mlir::StoreOp>(location, current_value, lhs, zero_index2);
llvm::ArrayRef<mlir::Value> zero_index2(
get_or_create_constant_index_value(0, location, 64, symbol_table,
builder));
builder.create<mlir::StoreOp>(location, current_value, lhs, zero_index2);
} else if (assignment_op == "*=") {
if (lhs_type.isa<mlir::FloatType>() || rhs_type.isa<mlir::FloatType>()) {
current_value = builder.create<mlir::MulFOp>(location, load_result,
load_result_rhs);
} else if (lhs_type.isa<mlir::IntegerType>() &&
rhs_type.isa<mlir::IntegerType>()) {
current_value = builder.create<mlir::MulIOp>(location, load_result,
load_result_rhs);
}
} else if (assignment_op == "-=") {
// If either are floats, use float subtraction
if (lhs_type.isa<mlir::FloatType>() || rhs_type.isa<mlir::FloatType>()) {
current_value =
builder.create<mlir::SubFOp>(location, load_result, load_result_rhs);
} else if (lhs_type.isa<mlir::IntegerType>() &&
rhs_type.isa<mlir::IntegerType>()) {
// Else both must be integers to perform integer subtraction
current_value =
builder.create<mlir::SubIOp>(location, load_result, load_result_rhs);
} else {
printErrorMessage("Could not perform -= for values of these types.",
context, {lhs, rhs});
}
llvm::ArrayRef<mlir::Value> zero_index2(
get_or_create_constant_index_value(0, location, 64, symbol_table,
builder));
builder.create<mlir::StoreOp>(location, current_value, lhs, zero_index2);
} else if (assignment_op == "/=") {
if (lhs_type.isa<mlir::FloatType>() || rhs_type.isa<mlir::FloatType>()) {
current_value = builder.create<mlir::DivFOp>(location, load_result,
load_result_rhs);
} else if (lhs_type.isa<mlir::IntegerType>() &&
rhs_type.isa<mlir::IntegerType>()) {
current_value = builder.create<mlir::UnsignedDivIOp>(
location, load_result, load_result_rhs);
}
// Store the added value to the lhs
llvm::ArrayRef<mlir::Value> zero_index2(get_or_create_constant_index_value(
0, location, 64, symbol_table, builder));
builder.create<mlir::StoreOp>(location, current_value, lhs, zero_index2);
llvm::ArrayRef<mlir::Value> zero_index2(
get_or_create_constant_index_value(0, location, 64, symbol_table,
builder));
builder.create<mlir::StoreOp>(location, current_value, lhs, zero_index2);
} else if (assignment_op == "*=") {
// If either are floats, use float multiplication
if (lhs_type.isa<mlir::FloatType>() || rhs_type.isa<mlir::FloatType>()) {
current_value =
builder.create<mlir::MulFOp>(location, load_result, load_result_rhs);
} else if (lhs_type.isa<mlir::IntegerType>() &&
rhs_type.isa<mlir::IntegerType>()) {
// Else both must be integers to perform integer subtraction
current_value =
builder.create<mlir::MulIOp>(location, load_result, load_result_rhs);
} else {
printErrorMessage("Could not perform *= for values of these types.",
context, {lhs, rhs});
}
} else if (assignment_op == "^=") {
// Store the added value to the lhs
llvm::ArrayRef<mlir::Value> zero_index2(get_or_create_constant_index_value(
0, location, 64, symbol_table, builder));
builder.create<mlir::StoreOp>(location, current_value, lhs, zero_index2);
} else if (assignment_op == "/=") {
if (lhs_type.isa<mlir::FloatType>() || rhs_type.isa<mlir::FloatType>()) {
current_value =
builder.create<mlir::XOrOp>(location, load_result, load_result_rhs);
llvm::ArrayRef<mlir::Value> zero_index2(
get_or_create_constant_index_value(0, location, 64, symbol_table,
builder));
builder.create<mlir::StoreOp>(location, current_value, lhs, zero_index2);
} else if (assignment_op == "=") {
builder.create<mlir::StoreOp>(
location, rhs, lhs,
get_or_create_constant_index_value(0, location, 64, symbol_table,
builder));
builder.create<mlir::DivFOp>(location, load_result, load_result_rhs);
} else if (lhs_type.isa<mlir::IntegerType>() &&
rhs_type.isa<mlir::IntegerType>()) {
current_value = builder.create<mlir::UnsignedDivIOp>(
location, load_result, load_result_rhs);
} else {
printErrorMessage(ass_op->getText() + " not yet supported for this type.",
context);
printErrorMessage("Could not perform /= for values of these types.",
context, {lhs, rhs});
}