Loading mlir/parsers/qasm3/utils/expression_handler.cpp +62 −30 Original line number Diff line number Diff line Loading @@ -308,10 +308,10 @@ antlrcpp::Any qasm3_expression_generator::visitAdditiveExpression( if (auto has_sub_additive_expr = ctx->additiveExpression()) { auto bin_op = ctx->binary_op->getText(); visitChildren(has_sub_additive_expr); visit(has_sub_additive_expr); auto lhs = current_value; visitChildren(ctx->multiplicativeExpression()); visit(ctx->multiplicativeExpression()); auto rhs = current_value; if (lhs.getType().isa<mlir::MemRefType>()) { Loading Loading @@ -358,21 +358,23 @@ antlrcpp::Any qasm3_expression_generator::visitAdditiveExpression( ctx, {lhs, rhs}); } } // else { // printErrorMessage("Could not perform addition, incompatible types: // " + // ctx->getText()); // } createOp<mlir::AddFOp>(location, lhs, rhs); } else if (lhs.getType().isa<mlir::IntegerType>() && rhs.getType().isa<mlir::IntegerType>()) { if (lhs.getType().getIntOrFloatBitWidth() < rhs.getType().getIntOrFloatBitWidth()) { lhs = builder.create<mlir::ZeroExtendIOp>(location, lhs,rhs.getType() ); } if (rhs.getType().getIntOrFloatBitWidth() < lhs.getType().getIntOrFloatBitWidth()) { rhs = builder.create<mlir::ZeroExtendIOp>(location, rhs,lhs.getType() ); } createOp<mlir::AddIOp>(location, lhs, rhs).result(); } else { printErrorMessage("Could not perform addition, incompatible types: ", ctx, {lhs, rhs}); } } else if (bin_op == "-") { if (lhs.getType().isa<mlir::FloatType>() || rhs.getType().isa<mlir::FloatType>()) { // One of these at least is a float, need to have Loading Loading @@ -601,15 +603,31 @@ antlrcpp::Any qasm3_expression_generator::visitXOrExpression( antlrcpp::Any qasm3_expression_generator::visitMultiplicativeExpression( qasm3Parser::MultiplicativeExpressionContext* ctx) { auto location = get_location(builder, file_name, ctx); if (auto mult_expr = ctx->multiplicativeExpression()) { auto bin_op = ctx->binary_op->getText(); visitExpressionTerminator(mult_expr->expressionTerminator()); auto lhs = current_value; visitExpressionTerminator(ctx->expressionTerminator()); auto rhs = current_value; if (lhs.getType().isa<mlir::MemRefType>()) { lhs = builder.create<mlir::LoadOp>( location, lhs, get_or_create_constant_index_value(0, location, 64, symbol_table, builder)); } if (rhs.getType().isa<mlir::MemRefType>()) { rhs = builder.create<mlir::LoadOp>( location, rhs, get_or_create_constant_index_value(0, location, 64, symbol_table, builder)); } if (bin_op == "*") { if (lhs.getType().isa<mlir::FloatType>() || rhs.getType().isa<mlir::FloatType>()) { Loading Loading @@ -640,11 +658,6 @@ antlrcpp::Any qasm3_expression_generator::visitMultiplicativeExpression( ctx, {lhs, rhs}); } } // else { // printErrorMessage( // "Could not perform multiplication, incompatible types: ", ctx, // {lhs, rhs}); // } createOp<mlir::MulFOp>(location, lhs, rhs); } else if (lhs.getType().isa<mlir::IntegerType>() && Loading Loading @@ -685,11 +698,6 @@ antlrcpp::Any qasm3_expression_generator::visitMultiplicativeExpression( "Must cast rhs to float, but it is not constant."); } } // else { // std::cout << "MADE IT HERE\n"; // printErrorMessage("Could not perform division, incompatible types: // ", ctx, {lhs, rhs}); // } createOp<mlir::DivFOp>(location, lhs, rhs); } else if (lhs.getType().isa<mlir::IntegerType>() && Loading Loading @@ -742,6 +750,11 @@ antlrcpp::Any qasm3_expression_generator::visitExpressionTerminator( return 0; } if (ctx->LPAREN() && ctx->RPAREN()) { visit(ctx->expression()); return 0; } if (ctx->Constant()) { auto const_str = ctx->Constant()->getText(); // std::cout << ctx->Constant()->getText() << "\n"; Loading Loading @@ -797,6 +810,18 @@ antlrcpp::Any qasm3_expression_generator::visitExpressionTerminator( 0, location, builder.getIntegerType(1), symbol_table, builder); } else { value = symbol_table.get_symbol(id->getText()); // If we are not in global scope and this value is // marked const, then I want to re-create it and return // that, this will mimic using global constants in downstream // scopes // if (symbol_table.get_current_scope() != 0) { // auto var_attrs = symbol_table.get_variable_attributes(id->getText()); // if (!var_attrs.empty() && std::find(var_attrs.begin(), var_attrs.end(), // "const") != std::end(var_attrs)) { // auto constant_val = value.getDefiningOp<mlir::ConstantOp>().value(); // value = builder.create<mlir::ConstantOp>(location, constant_val); // } // } } update_current_value(value); Loading Loading @@ -1145,15 +1170,13 @@ antlrcpp::Any qasm3_expression_generator::visitExpressionTerminator( qubit_expr_list_idx = 1; for (auto expression : expression_list[0]->expression()) { // std::cout << "Subcall expr: " << expression->getText() << "\n"; // add parameter values: // FIXME THIS SHOULD MATCH TYPES for FUNCTION auto value = std::stod(expression->getText()); auto float_attr = mlir::FloatAttr::get(builder.getF64Type(), value); mlir::Value val = builder.create<mlir::ConstantOp>(location, float_attr); operands.push_back(val); qasm3_expression_generator param_exp_generator(builder, symbol_table, file_name); param_exp_generator.visit(expression); operands.push_back(param_exp_generator.current_value); } // Here we add all global variables } for (auto expression : expression_list[qubit_expr_list_idx]->expression()) { Loading @@ -1165,8 +1188,17 @@ antlrcpp::Any qasm3_expression_generator::visitExpressionTerminator( operands.push_back(qubit_exp_generator.current_value); } auto call_op = builder.create<mlir::CallOp>(location, func, llvm::makeArrayRef(operands)); update_current_value(call_op.getResult(0)); llvm::makeArrayRef(operands)).getResult(0); // If RHS is a memref<1xTYPE> then lets load it first if (auto rhs_mem = call_op.getType().dyn_cast_or_null<mlir::MemRefType>()) { call_op = builder.create<mlir::LoadOp>( location, call_op, get_or_create_constant_index_value(0, location, 64, symbol_table, builder)); } // printErrorMessage("HELLO should we return the loaded result here?", ctx, {call_op.getResult(0)}); update_current_value(call_op); return 0; } else if (auto kernel_call = ctx->kernelCall()) { Loading @@ -1179,8 +1211,8 @@ antlrcpp::Any qasm3_expression_generator::visitExpressionTerminator( std::vector<mlir::Value> operands; auto expression_list = kernel_call->expressionList()->expression(); for (auto expression : expression_list) { qasm3_expression_generator param_exp_generator( builder, symbol_table, file_name); qasm3_expression_generator param_exp_generator(builder, symbol_table, file_name); param_exp_generator.visit(expression); operands.push_back(param_exp_generator.current_value); } Loading Loading
mlir/parsers/qasm3/utils/expression_handler.cpp +62 −30 Original line number Diff line number Diff line Loading @@ -308,10 +308,10 @@ antlrcpp::Any qasm3_expression_generator::visitAdditiveExpression( if (auto has_sub_additive_expr = ctx->additiveExpression()) { auto bin_op = ctx->binary_op->getText(); visitChildren(has_sub_additive_expr); visit(has_sub_additive_expr); auto lhs = current_value; visitChildren(ctx->multiplicativeExpression()); visit(ctx->multiplicativeExpression()); auto rhs = current_value; if (lhs.getType().isa<mlir::MemRefType>()) { Loading Loading @@ -358,21 +358,23 @@ antlrcpp::Any qasm3_expression_generator::visitAdditiveExpression( ctx, {lhs, rhs}); } } // else { // printErrorMessage("Could not perform addition, incompatible types: // " + // ctx->getText()); // } createOp<mlir::AddFOp>(location, lhs, rhs); } else if (lhs.getType().isa<mlir::IntegerType>() && rhs.getType().isa<mlir::IntegerType>()) { if (lhs.getType().getIntOrFloatBitWidth() < rhs.getType().getIntOrFloatBitWidth()) { lhs = builder.create<mlir::ZeroExtendIOp>(location, lhs,rhs.getType() ); } if (rhs.getType().getIntOrFloatBitWidth() < lhs.getType().getIntOrFloatBitWidth()) { rhs = builder.create<mlir::ZeroExtendIOp>(location, rhs,lhs.getType() ); } createOp<mlir::AddIOp>(location, lhs, rhs).result(); } else { printErrorMessage("Could not perform addition, incompatible types: ", ctx, {lhs, rhs}); } } else if (bin_op == "-") { if (lhs.getType().isa<mlir::FloatType>() || rhs.getType().isa<mlir::FloatType>()) { // One of these at least is a float, need to have Loading Loading @@ -601,15 +603,31 @@ antlrcpp::Any qasm3_expression_generator::visitXOrExpression( antlrcpp::Any qasm3_expression_generator::visitMultiplicativeExpression( qasm3Parser::MultiplicativeExpressionContext* ctx) { auto location = get_location(builder, file_name, ctx); if (auto mult_expr = ctx->multiplicativeExpression()) { auto bin_op = ctx->binary_op->getText(); visitExpressionTerminator(mult_expr->expressionTerminator()); auto lhs = current_value; visitExpressionTerminator(ctx->expressionTerminator()); auto rhs = current_value; if (lhs.getType().isa<mlir::MemRefType>()) { lhs = builder.create<mlir::LoadOp>( location, lhs, get_or_create_constant_index_value(0, location, 64, symbol_table, builder)); } if (rhs.getType().isa<mlir::MemRefType>()) { rhs = builder.create<mlir::LoadOp>( location, rhs, get_or_create_constant_index_value(0, location, 64, symbol_table, builder)); } if (bin_op == "*") { if (lhs.getType().isa<mlir::FloatType>() || rhs.getType().isa<mlir::FloatType>()) { Loading Loading @@ -640,11 +658,6 @@ antlrcpp::Any qasm3_expression_generator::visitMultiplicativeExpression( ctx, {lhs, rhs}); } } // else { // printErrorMessage( // "Could not perform multiplication, incompatible types: ", ctx, // {lhs, rhs}); // } createOp<mlir::MulFOp>(location, lhs, rhs); } else if (lhs.getType().isa<mlir::IntegerType>() && Loading Loading @@ -685,11 +698,6 @@ antlrcpp::Any qasm3_expression_generator::visitMultiplicativeExpression( "Must cast rhs to float, but it is not constant."); } } // else { // std::cout << "MADE IT HERE\n"; // printErrorMessage("Could not perform division, incompatible types: // ", ctx, {lhs, rhs}); // } createOp<mlir::DivFOp>(location, lhs, rhs); } else if (lhs.getType().isa<mlir::IntegerType>() && Loading Loading @@ -742,6 +750,11 @@ antlrcpp::Any qasm3_expression_generator::visitExpressionTerminator( return 0; } if (ctx->LPAREN() && ctx->RPAREN()) { visit(ctx->expression()); return 0; } if (ctx->Constant()) { auto const_str = ctx->Constant()->getText(); // std::cout << ctx->Constant()->getText() << "\n"; Loading Loading @@ -797,6 +810,18 @@ antlrcpp::Any qasm3_expression_generator::visitExpressionTerminator( 0, location, builder.getIntegerType(1), symbol_table, builder); } else { value = symbol_table.get_symbol(id->getText()); // If we are not in global scope and this value is // marked const, then I want to re-create it and return // that, this will mimic using global constants in downstream // scopes // if (symbol_table.get_current_scope() != 0) { // auto var_attrs = symbol_table.get_variable_attributes(id->getText()); // if (!var_attrs.empty() && std::find(var_attrs.begin(), var_attrs.end(), // "const") != std::end(var_attrs)) { // auto constant_val = value.getDefiningOp<mlir::ConstantOp>().value(); // value = builder.create<mlir::ConstantOp>(location, constant_val); // } // } } update_current_value(value); Loading Loading @@ -1145,15 +1170,13 @@ antlrcpp::Any qasm3_expression_generator::visitExpressionTerminator( qubit_expr_list_idx = 1; for (auto expression : expression_list[0]->expression()) { // std::cout << "Subcall expr: " << expression->getText() << "\n"; // add parameter values: // FIXME THIS SHOULD MATCH TYPES for FUNCTION auto value = std::stod(expression->getText()); auto float_attr = mlir::FloatAttr::get(builder.getF64Type(), value); mlir::Value val = builder.create<mlir::ConstantOp>(location, float_attr); operands.push_back(val); qasm3_expression_generator param_exp_generator(builder, symbol_table, file_name); param_exp_generator.visit(expression); operands.push_back(param_exp_generator.current_value); } // Here we add all global variables } for (auto expression : expression_list[qubit_expr_list_idx]->expression()) { Loading @@ -1165,8 +1188,17 @@ antlrcpp::Any qasm3_expression_generator::visitExpressionTerminator( operands.push_back(qubit_exp_generator.current_value); } auto call_op = builder.create<mlir::CallOp>(location, func, llvm::makeArrayRef(operands)); update_current_value(call_op.getResult(0)); llvm::makeArrayRef(operands)).getResult(0); // If RHS is a memref<1xTYPE> then lets load it first if (auto rhs_mem = call_op.getType().dyn_cast_or_null<mlir::MemRefType>()) { call_op = builder.create<mlir::LoadOp>( location, call_op, get_or_create_constant_index_value(0, location, 64, symbol_table, builder)); } // printErrorMessage("HELLO should we return the loaded result here?", ctx, {call_op.getResult(0)}); update_current_value(call_op); return 0; } else if (auto kernel_call = ctx->kernelCall()) { Loading @@ -1179,8 +1211,8 @@ antlrcpp::Any qasm3_expression_generator::visitExpressionTerminator( std::vector<mlir::Value> operands; auto expression_list = kernel_call->expressionList()->expression(); for (auto expression : expression_list) { qasm3_expression_generator param_exp_generator( builder, symbol_table, file_name); qasm3_expression_generator param_exp_generator(builder, symbol_table, file_name); param_exp_generator.visit(expression); operands.push_back(param_exp_generator.current_value); } Loading