Commit cbf08d0f authored by Marcello Maggioni's avatar Marcello Maggioni
Browse files

[mlir] Fix LLVM intrinsic convesion generator for overloadable types.

Summary:
If an intrinsic has overloadable types like llvm_anyint_ty or
llvm_anyfloat_ty then to getDeclaration() we need to pass a list
of the types that are "undefined" essentially concretizing them.

This patch add support for deriving such types from the MLIR op
that has been matched.

Reviewers: andydavis1, ftynse, nicolasvasilache, antiagainst

Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, arpith-jacob, mgester, lucyrfox, liufengdb, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D72974
parent 2d77e0b9
Loading
Loading
Loading
Loading
+7 −2
Original line number Diff line number Diff line
@@ -6,16 +6,21 @@
// includes from the main file to avoid unnecessary dependencies and decrease
// the test cost. The command-line flags further ensure a specific intrinsic is
// processed and we only check the ouptut below.
// We also verify emission of type specialization for overloadable intrinsics.
//
// RUN: cat %S/../../../llvm/include/llvm/IR/Intrinsics.td \
// RUN: | grep -v "llvm/IR/Intrinsics" \
// RUN: | mlir-tblgen -gen-llvmir-intrinsics -I %S/../../../llvm/include/ --llvmir-intrinsics-filter=vastart \
// RUN: | mlir-tblgen -gen-llvmir-intrinsics -I %S/../../../llvm/include/ --llvmir-intrinsics-filter=is_constant \
// RUN: | FileCheck %s

// CHECK-LABEL: def LLVM_vastart
// CHECK-LABEL: def LLVM_is_constant
// CHECK: LLVM_Op<"intr
// CHECK: Arguments<(ins
// CHECK: Results<(outs
// CHECK: llvm::Function *fn = llvm::Intrinsic::getDeclaration(
// CHECK:        module, llvm::Intrinsic::is_constant, {
// CHECK:        opInst.getOperand(0).getType().cast<LLVM::LLVMType>().getUnderlyingType(),
// CHECK: });

//---------------------------------------------------------------------------//

+56 −1
Original line number Diff line number Diff line
@@ -14,7 +14,9 @@
#include "mlir/Support/STLExtras.h"
#include "mlir/TableGen/GenInfo.h"

#include "llvm/ADT/SmallBitVector.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/MachineValueType.h"
#include "llvm/Support/PrettyStackTrace.h"
#include "llvm/Support/Signals.h"
#include "llvm/TableGen/Error.h"
@@ -30,6 +32,38 @@ static llvm::cl::opt<std::string>
                              "substring in their record name"),
               llvm::cl::cat(IntrinsicGenCat));

// Used to represent the indices of overloadable operands/results.
using IndicesTy = llvm::SmallBitVector;

/// Return a CodeGen value type entry from a type record.
static llvm::MVT::SimpleValueType getValueType(const llvm::Record *rec) {
  return (llvm::MVT::SimpleValueType)rec->getValueAsDef("VT")->getValueAsInt(
      "Value");
}

/// Return the indices of the definitions in a list of definitions that
/// represent overloadable types
static IndicesTy getOverloadableTypeIdxs(const llvm::Record &record,
                                         const char *listName) {
  auto results = record.getValueAsListOfDefs(listName);
  IndicesTy overloadedOps(results.size());
  for (auto r : llvm::enumerate(results)) {
    llvm::MVT::SimpleValueType vt = getValueType(r.value());
    switch (vt) {
    case llvm::MVT::iAny:
    case llvm::MVT::fAny:
    case llvm::MVT::Any:
    case llvm::MVT::iPTRAny:
    case llvm::MVT::vAny:
      overloadedOps.set(r.index());
      break;
    default:
      continue;
    }
  }
  return overloadedOps;
}

namespace {
/// A wrapper for LLVM's Tablegen class `Intrinsic` that provides accessors to
/// the fields of the record.
@@ -108,6 +142,14 @@ public:
    return false;
  }

  IndicesTy getOverloadableOperandsIdxs() const {
    return getOverloadableTypeIdxs(record, fieldOperands);
  }

  IndicesTy getOverloadableResultsIdxs() const {
    return getOverloadableTypeIdxs(record, fieldResults);
  }

private:
  /// Names of the fields in the Intrinsic LLVM Tablegen class.
  const char *fieldName = "LLVMName";
@@ -122,10 +164,23 @@ private:
/// Emits C++ code constructing an LLVM IR intrinsic given the generated MLIR
/// operation.  In LLVM IR, intrinsics are constructed as function calls.
static void emitBuilder(const LLVMIntrinsic &intr, llvm::raw_ostream &os) {
  auto overloadedRes = intr.getOverloadableResultsIdxs();
  auto overloadedOps = intr.getOverloadableOperandsIdxs();
  os << "    llvm::Module *module = builder.GetInsertBlock()->getModule();\n";
  os << "    llvm::Function *fn = llvm::Intrinsic::getDeclaration(\n";
  os << "        module, llvm::Intrinsic::" << intr.getProperRecordName()
     << ");\n";
     << ", {";
  for (unsigned idx : overloadedRes.set_bits()) {
    os << "\n        opInst.getResult(" << idx << ").getType()"
       << ".cast<LLVM::LLVMType>().getUnderlyingType(),";
  }
  for (unsigned idx : overloadedOps.set_bits()) {
    os << "\n        opInst.getOperand(" << idx << ").getType()"
       << ".cast<LLVM::LLVMType>().getUnderlyingType(),";
  }
  if (overloadedRes.any() || overloadedOps.any())
    os << "\n  ";
  os << "});\n";
  os << "    auto operands = llvm::to_vector<8, Value *>(\n";
  os << "        opInst.operand_begin(), opInst.operand_end());\n";
  os << "    " << (intr.getNumResults() > 0 ? "$res = " : "")