Commit de2e358c authored by Mccaskey, Alex's avatar Mccaskey, Alex
Browse files

general expression handler updates

parent 923c9058
Loading
Loading
Loading
Loading
+62 −30
Original line number Diff line number Diff line
@@ -308,10 +308,10 @@ antlrcpp::Any qasm3_expression_generator::visitAdditiveExpression(
  if (auto has_sub_additive_expr = ctx->additiveExpression()) {
    auto bin_op = ctx->binary_op->getText();

    visitChildren(has_sub_additive_expr);
    visit(has_sub_additive_expr);
    auto lhs = current_value;

    visitChildren(ctx->multiplicativeExpression());
    visit(ctx->multiplicativeExpression());
    auto rhs = current_value;
    
    if (lhs.getType().isa<mlir::MemRefType>()) {
@@ -358,21 +358,23 @@ antlrcpp::Any qasm3_expression_generator::visitAdditiveExpression(
                              ctx, {lhs, rhs});
          }
        }
        // else {
        //   printErrorMessage("Could not perform addition, incompatible types:
        //   " +
        //                     ctx->getText());
        // }

        createOp<mlir::AddFOp>(location, lhs, rhs);
      } else if (lhs.getType().isa<mlir::IntegerType>() &&
                 rhs.getType().isa<mlir::IntegerType>()) {
        if (lhs.getType().getIntOrFloatBitWidth() < rhs.getType().getIntOrFloatBitWidth()) {
          lhs = builder.create<mlir::ZeroExtendIOp>(location, lhs,rhs.getType() );
        }
        if (rhs.getType().getIntOrFloatBitWidth() < lhs.getType().getIntOrFloatBitWidth()) {
          rhs = builder.create<mlir::ZeroExtendIOp>(location, rhs,lhs.getType() );
        }
        createOp<mlir::AddIOp>(location, lhs, rhs).result();
      } else {
        printErrorMessage("Could not perform addition, incompatible types: ",
                          ctx, {lhs, rhs});
      }
    } else if (bin_op == "-") {

      if (lhs.getType().isa<mlir::FloatType>() ||
          rhs.getType().isa<mlir::FloatType>()) {
        // One of these at least is a float, need to have
@@ -601,15 +603,31 @@ antlrcpp::Any qasm3_expression_generator::visitXOrExpression(
antlrcpp::Any qasm3_expression_generator::visitMultiplicativeExpression(
    qasm3Parser::MultiplicativeExpressionContext* ctx) {
  auto location = get_location(builder, file_name, ctx);

  if (auto mult_expr = ctx->multiplicativeExpression()) {
    auto bin_op = ctx->binary_op->getText();


    visitExpressionTerminator(mult_expr->expressionTerminator());
    auto lhs = current_value;

    visitExpressionTerminator(ctx->expressionTerminator());
    auto rhs = current_value;

    if (lhs.getType().isa<mlir::MemRefType>()) {
      lhs = builder.create<mlir::LoadOp>(
          location, lhs,
          get_or_create_constant_index_value(0, location, 64, symbol_table,
                                             builder));
    }

    if (rhs.getType().isa<mlir::MemRefType>()) {
      rhs = builder.create<mlir::LoadOp>(
          location, rhs,
          get_or_create_constant_index_value(0, location, 64, symbol_table,
                                             builder));
    }

    if (bin_op == "*") {
      if (lhs.getType().isa<mlir::FloatType>() ||
          rhs.getType().isa<mlir::FloatType>()) {
@@ -640,11 +658,6 @@ antlrcpp::Any qasm3_expression_generator::visitMultiplicativeExpression(
                              ctx, {lhs, rhs});
          }
        }
        // else {
        //   printErrorMessage(
        //       "Could not perform multiplication, incompatible types: ", ctx,
        //     {lhs, rhs});
        // }

        createOp<mlir::MulFOp>(location, lhs, rhs);
      } else if (lhs.getType().isa<mlir::IntegerType>() &&
@@ -685,11 +698,6 @@ antlrcpp::Any qasm3_expression_generator::visitMultiplicativeExpression(
                "Must cast rhs to float, but it is not constant.");
          }
        }
        // else {
        //   std::cout << "MADE IT HERE\n";
        //   printErrorMessage("Could not perform division, incompatible types:
        //   ", ctx, {lhs, rhs});
        // }

        createOp<mlir::DivFOp>(location, lhs, rhs);
      } else if (lhs.getType().isa<mlir::IntegerType>() &&
@@ -742,6 +750,11 @@ antlrcpp::Any qasm3_expression_generator::visitExpressionTerminator(
    return 0;
  }

  if (ctx->LPAREN() && ctx->RPAREN()) {
    visit(ctx->expression());
    return 0;
  }

  if (ctx->Constant()) {
    auto const_str = ctx->Constant()->getText();
    // std::cout << ctx->Constant()->getText() << "\n";
@@ -797,6 +810,18 @@ antlrcpp::Any qasm3_expression_generator::visitExpressionTerminator(
          0, location, builder.getIntegerType(1), symbol_table, builder);
    } else {
      value = symbol_table.get_symbol(id->getText());
      // If we are not in global scope and this value is
      // marked const, then I want to re-create it and return
      // that, this will mimic using global constants in downstream
      // scopes
      // if (symbol_table.get_current_scope() != 0) {
      //   auto var_attrs = symbol_table.get_variable_attributes(id->getText());
      //   if (!var_attrs.empty() && std::find(var_attrs.begin(), var_attrs.end(),
      //                                       "const") != std::end(var_attrs)) {
      //     auto constant_val = value.getDefiningOp<mlir::ConstantOp>().value();
      //     value = builder.create<mlir::ConstantOp>(location, constant_val);
      //   }
      // }
    }
    update_current_value(value);

@@ -1145,15 +1170,13 @@ antlrcpp::Any qasm3_expression_generator::visitExpressionTerminator(
      qubit_expr_list_idx = 1;

      for (auto expression : expression_list[0]->expression()) {
        // std::cout << "Subcall expr: " << expression->getText() << "\n";
        // add parameter values:
        // FIXME THIS SHOULD MATCH TYPES for FUNCTION
        auto value = std::stod(expression->getText());
        auto float_attr = mlir::FloatAttr::get(builder.getF64Type(), value);
        mlir::Value val =
            builder.create<mlir::ConstantOp>(location, float_attr);
        operands.push_back(val);
        qasm3_expression_generator param_exp_generator(builder, symbol_table,
                                                       file_name);
        param_exp_generator.visit(expression);
        operands.push_back(param_exp_generator.current_value);
      }

      // Here we add all global variables
    }

    for (auto expression : expression_list[qubit_expr_list_idx]->expression()) {
@@ -1165,8 +1188,17 @@ antlrcpp::Any qasm3_expression_generator::visitExpressionTerminator(
      operands.push_back(qubit_exp_generator.current_value);
    }
    auto call_op = builder.create<mlir::CallOp>(location, func,
                                                llvm::makeArrayRef(operands));
    update_current_value(call_op.getResult(0));
                                                llvm::makeArrayRef(operands)).getResult(0);
    // If RHS is a memref<1xTYPE> then lets load it first
    if (auto rhs_mem = call_op.getType().dyn_cast_or_null<mlir::MemRefType>()) {
      call_op = builder.create<mlir::LoadOp>(
          location, call_op,
          get_or_create_constant_index_value(0, location, 64,
                                             symbol_table, builder));
    }
    // printErrorMessage("HELLO should we return the loaded result here?", ctx, {call_op.getResult(0)});

    update_current_value(call_op);

    return 0;
  } else if (auto kernel_call = ctx->kernelCall()) {
@@ -1179,8 +1211,8 @@ antlrcpp::Any qasm3_expression_generator::visitExpressionTerminator(
    std::vector<mlir::Value> operands;
    auto expression_list = kernel_call->expressionList()->expression();
    for (auto expression : expression_list) {
      qasm3_expression_generator param_exp_generator(
          builder, symbol_table, file_name);
      qasm3_expression_generator param_exp_generator(builder, symbol_table,
                                                     file_name);
      param_exp_generator.visit(expression);
      operands.push_back(param_exp_generator.current_value);
    }