Commit 890d5e2d authored by Stephan Herhut's avatar Stephan Herhut
Browse files

[MLIR][GPU] Disallow llvm tanh intrinsics when lowering to NVVM/ROCm.

Summary:
The lowering to NVVM and ROCm handles tanh operations differently by
mapping them to NVVM/ROCm specific intrinsics. This conflicts with
the lowering to LLVM, which uses the default llvm intrinsic. This change
declares the LLVM intrinsics to be illegal, hence disallowing the
correspondign rewrite.

Differential Revision: https://reviews.llvm.org/D74389
parent 98c940bf
Loading
Loading
Loading
Loading
+18 −0
Original line number Diff line number Diff line
@@ -95,6 +95,24 @@ private:
  const std::string f64Func;
};

namespace gpu {
/// Returns a predicate to be used with addDynamicallyLegalOp. The predicate
/// returns false for calls to the provided intrinsics and true otherwise.
inline std::function<bool(Operation *)>
filterIllegalLLVMIntrinsics(ArrayRef<StringRef> intrinsics, MLIRContext *ctx) {
  SmallVector<StringRef, 4> illegalIds(intrinsics.begin(), intrinsics.end());
  return [illegalIds](Operation *op) -> bool {
    LLVM::CallOp callOp = dyn_cast<LLVM::CallOp>(op);
    if (!callOp || !callOp.callee())
      return true;
    StringRef callee = callOp.callee().getValue();
    return !llvm::any_of(illegalIds, [callee](StringRef intrinsic) {
      return callee.equals(intrinsic);
    });
  };
}
} // namespace gpu

} // namespace mlir

#endif // MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
+2 −0
Original line number Diff line number Diff line
@@ -695,6 +695,8 @@ public:
    target.addIllegalOp<FuncOp>();
    target.addLegalDialect<LLVM::LLVMDialect>();
    target.addLegalDialect<NVVM::NVVMDialect>();
    target.addDynamicallyLegalOp<mlir::LLVM::CallOp>(
        gpu::filterIllegalLLVMIntrinsics({"tanh", "tanhf"}, m.getContext()));
    // TODO(csigg): Remove once we support replacing non-root ops.
    target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp, gpu::ModuleEndOp>();
    if (failed(applyPartialConversion(m, target, patterns, &converter)))
+3 −2
Original line number Diff line number Diff line
@@ -65,8 +65,9 @@ public:
    target.addLegalDialect<LLVM::LLVMDialect, ROCDL::ROCDLDialect>();
    target.addIllegalOp<LLVM::FAbsOp, LLVM::FCeilOp, LLVM::CosOp,
                        LLVM::ExpOp>();
    target.addDynamicallyLegalOp<FuncOp>(
        [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
    target.addDynamicallyLegalOp<LLVM::CallOp>(
        gpu::filterIllegalLLVMIntrinsics({"tanh", "tanhf"}, m.getContext()));
    target.addIllegalOp<FuncOp>();
    if (failed(applyPartialConversion(m, target, patterns, &converter)))
      signalPassFailure();
  }