Commit cfa82f77 authored by K-Wu's avatar K-Wu Committed by Kun Wu
Browse files

[mlir][sparse][gpu] introduce flag that controls host to device copy...

[mlir][sparse][gpu] introduce flag that controls host to device copy strategies (regular dma default)

Differential Revision: https://reviews.llvm.org/D155352
parent 9a806551
Loading
Loading
Loading
Loading
+18 −2
Original line number Diff line number Diff line
@@ -52,6 +52,21 @@ struct SparseCompilerOptions
              mlir::SparseParallelizationStrategy::kAnyStorageAnyLoop,
              "any-storage-any-loop",
              "Enable sparse parallelization for any storage and loop."))};
  PassOptions::Option<mlir::GPUDataTransferStrategy> gpuDataTransfer{
      *this, "gpu-data-transfer-strategy",
      ::llvm::cl::desc(
          "Set the data transfer strategy between the host and the GPUs"),
      ::llvm::cl::init(mlir::GPUDataTransferStrategy::kRegularDMA),
      llvm::cl::values(
          clEnumValN(mlir::GPUDataTransferStrategy::kRegularDMA, "regular-dma",
                     "Default option: malloc on host without additional "
                     "options or care and then use DMA to copy the data"),
          clEnumValN(mlir::GPUDataTransferStrategy::kPinnedDMA, "pinned-dma",
                     "Based on the default option, pin the host memory to "
                     "accelerate the data transfer"),
          clEnumValN(mlir::GPUDataTransferStrategy::kZeroCopy, "zero-copy",
                     "Use zero-copy to perform the data transfer from the host "
                     "to the GPU"))};

  PassOptions::Option<bool> enableIndexReduction{
      *this, "enable-index-reduction",
@@ -138,8 +153,9 @@ struct SparseCompilerOptions

  /// Projects out the options for `createSparsificationPass`.
  SparsificationOptions sparsificationOptions() const {
    return SparsificationOptions(parallelization, enableIndexReduction,
                                 enableGPULibgen, enableRuntimeLibrary);
    return SparsificationOptions(parallelization, gpuDataTransfer,
                                 enableIndexReduction, enableGPULibgen,
                                 enableRuntimeLibrary);
  }

  /// Projects out the options for `createSparseTensorConversionPass`.
+13 −6
Original line number Diff line number Diff line
@@ -44,19 +44,26 @@ enum class SparseParallelizationStrategy {
  // TODO: support reduction parallelization too?
};

// TODO: Zero copy is disabled due to correctness bugs. Tracker: #64316
enum class GPUDataTransferStrategy { kRegularDMA, kZeroCopy, kPinnedDMA };

#define GEN_PASS_DECL
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"

/// Options for the Sparsification pass.
struct SparsificationOptions {
  SparsificationOptions(SparseParallelizationStrategy p, bool idxReduc,
  SparsificationOptions(SparseParallelizationStrategy p,
                        GPUDataTransferStrategy t, bool idxReduc,
                        bool gpuLibgen, bool enableRT)
      : parallelizationStrategy(p), enableIndexReduction(idxReduc),
        enableGPULibgen(gpuLibgen), enableRuntimeLibrary(enableRT) {}
      : parallelizationStrategy(p), gpuDataTransferStrategy(t),
        enableIndexReduction(idxReduc), enableGPULibgen(gpuLibgen),
        enableRuntimeLibrary(enableRT) {}
  SparsificationOptions()
      : SparsificationOptions(SparseParallelizationStrategy::kNone, false,
      : SparsificationOptions(SparseParallelizationStrategy::kNone,
                              GPUDataTransferStrategy::kRegularDMA, false,
                              false, true) {}
  SparseParallelizationStrategy parallelizationStrategy;
  GPUDataTransferStrategy gpuDataTransferStrategy;
  bool enableIndexReduction;
  bool enableGPULibgen;
  bool enableRuntimeLibrary;
@@ -211,8 +218,8 @@ std::unique_ptr<Pass> createSparseVectorizationPass(unsigned vectorLength,
void populateSparseGPUCodegenPatterns(RewritePatternSet &patterns,
                                      unsigned numThreads);

void populateSparseGPULibgenPatterns(RewritePatternSet &patterns,
                                     bool enableRT);
void populateSparseGPULibgenPatterns(RewritePatternSet &patterns, bool enableRT,
                                     GPUDataTransferStrategy gpuDataTransfer);

std::unique_ptr<Pass> createSparseGPUCodegenPass();
std::unique_ptr<Pass> createSparseGPUCodegenPass(unsigned numThreads);
+14 −0
Original line number Diff line number Diff line
@@ -102,6 +102,19 @@ def SparsificationPass : Pass<"sparsification", "ModuleOp"> {
             clEnumValN(mlir::SparseParallelizationStrategy::kAnyStorageAnyLoop,
                        "any-storage-any-loop",
                        "Enable sparse parallelization for any storage and loop."))}]>,
    Option<"gpuDataTransfer", "gpu-data-transfer-strategy", "mlir::GPUDataTransferStrategy",
            "mlir::GPUDataTransferStrategy::kRegularDMA",
            "Set the data transfer strategy", [{llvm::cl::values(
               clEnumValN(mlir::GPUDataTransferStrategy::kRegularDMA,
                     "regular-dma",
                     "Default option: malloc on host without additional "
                     "options or care and then use DMA to copy the data"),
          clEnumValN(mlir::GPUDataTransferStrategy::kPinnedDMA, "pinned-dma",
                     "Based on the default option, pin the host memory to "
                     "accelerate the data transfer"),
          clEnumValN(mlir::GPUDataTransferStrategy::kZeroCopy, "zero-copy",
                     "Use zero-copy to perform the data transfer from the host "
                     "to the GPU"))}]>,
    Option<"enableGPULibgen", "enable-gpu-libgen", "bool",
           "false",
           "Enable GPU acceleration by means of direct library calls (like cuSPARSE)">,
@@ -110,6 +123,7 @@ def SparsificationPass : Pass<"sparsification", "ModuleOp"> {
  ];
}


def PostSparsificationRewrite : Pass<"post-sparsification-rewrite", "ModuleOp"> {
  let summary = "Applies sparse tensor rewriting rules after sparsification";
  let description = [{
+150 −33
Original line number Diff line number Diff line
@@ -461,14 +461,18 @@ static Operation *genSpMat(OpBuilder &builder, Location loc, Type handleTp,
}

/// Match and rewrite SpMV kernel.
static LogicalResult rewriteSpMV(PatternRewriter &rewriter,
                                 linalg::GenericOp op, bool enableRT) {
static LogicalResult
rewriteSpMV(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
            GPUDataTransferStrategy gpuDataTransferStrategy) {
  Location loc = op.getLoc();
  Value a = op.getOperand(0);
  Value x = op.getOperand(1);
  Value y = op.getOperand(2); // we have y = Ax
  SmallVector<Value> tokens;

  bool isZeroCopy =
      gpuDataTransferStrategy == GPUDataTransferStrategy::kZeroCopy;

  // Only admissible sparse matrix format and dense vectors.
  bool isCOO = false;
  SparseTensorType aTp = getSparseTensorType(a);
@@ -487,12 +491,27 @@ static LogicalResult rewriteSpMV(PatternRewriter &rewriter,
  Value memR = genFirstPosOrCrds(rewriter, loc, a, isCOO, enableRT);
  Value memC = genSecondCrds(rewriter, loc, a, isCOO, enableRT);
  Value memV = genToValues(rewriter, loc, a);
  Value memX, memY;
  Value castR, castC, castV, castX, castY;
  if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) {
    memX = genTensorToMemref(rewriter, loc, x);
    memY = genTensorToMemref(rewriter, loc, y);
    castR = genHostRegisterMemref(rewriter, loc, memR);
    if (memC)
      castC = genHostRegisterMemref(rewriter, loc, memC);
    castV = genHostRegisterMemref(rewriter, loc, memV);
    castX = genHostRegisterMemref(rewriter, loc, memX);
    castY = genHostRegisterMemref(rewriter, loc, memY);
  }

  Value rowA = genAllocCopy(rewriter, loc, memR, tokens);
  Value colA = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
  Value valA = genAllocCopy(rewriter, loc, memV, tokens);
  Value memX = genTensorToMemref(rewriter, loc, x);
  Value vecX = genAllocCopy(rewriter, loc, memX, tokens);
  Value memY = genTensorToMemref(rewriter, loc, y);
  if (gpuDataTransferStrategy == GPUDataTransferStrategy::kRegularDMA)
    memX = genTensorToMemref(rewriter, loc, x);
  Value vecX = isZeroCopy ? memX : genAllocCopy(rewriter, loc, memX, tokens);
  if (gpuDataTransferStrategy == GPUDataTransferStrategy::kRegularDMA)
    memY = genTensorToMemref(rewriter, loc, y);
  Value vecY = genAllocCopy(rewriter, loc, memY, tokens);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();
@@ -546,11 +565,20 @@ static LogicalResult rewriteSpMV(PatternRewriter &rewriter,
    token = genDeallocMemRef(rewriter, loc, colA, token);
  token = genDeallocMemRef(rewriter, loc, valA, token);
  token = genDeallocMemRef(rewriter, loc, buffer, token);
  if (!isZeroCopy)
    token = genDeallocMemRef(rewriter, loc, vecX, token);
  token = genCopyMemRef(rewriter, loc, memY, vecY, token);
  token = genDeallocMemRef(rewriter, loc, vecY, token);
  tokens.push_back(token);
  genBlockingWait(rewriter, loc, tokens);
  if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) {
    genHostUnregisterMemref(rewriter, loc, castR);
    if (memC)
      genHostUnregisterMemref(rewriter, loc, castC);
    genHostUnregisterMemref(rewriter, loc, castV);
    genHostUnregisterMemref(rewriter, loc, castX);
    genHostUnregisterMemref(rewriter, loc, castY);
  }
  tokens.clear();

  // Done.
@@ -559,14 +587,18 @@ static LogicalResult rewriteSpMV(PatternRewriter &rewriter,
}

/// Match and rewrite SpMM kernel.
static LogicalResult rewriteSpMM(PatternRewriter &rewriter,
                                 linalg::GenericOp op, bool enableRT) {
static LogicalResult
rewriteSpMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
            GPUDataTransferStrategy gpuDataTransferStrategy) {
  Location loc = op.getLoc();
  Value a = op.getOperand(0);
  Value b = op.getOperand(1);
  Value c = op.getOperand(2); // we have C = AB
  SmallVector<Value> tokens;

  bool isZeroCopy =
      gpuDataTransferStrategy == GPUDataTransferStrategy::kZeroCopy;

  // Only admissible sparse matrix format and dense matrices.
  bool isCOO = false;
  SparseTensorType aTp = getSparseTensorType(a);
@@ -586,12 +618,27 @@ static LogicalResult rewriteSpMM(PatternRewriter &rewriter,
  Value memR = genFirstPosOrCrds(rewriter, loc, a, isCOO, enableRT);
  Value memC = genSecondCrds(rewriter, loc, a, isCOO, enableRT);
  Value memV = genToValues(rewriter, loc, a);
  Value bufB, bufC;
  Value castR, castC, castV, castB, castBufC;
  if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) {
    bufB = genTensorToMemref(rewriter, loc, b);
    bufC = genTensorToMemref(rewriter, loc, c);
    castR = genHostRegisterMemref(rewriter, loc, memR);
    if (memC)
      castC = genHostRegisterMemref(rewriter, loc, memC);
    castV = genHostRegisterMemref(rewriter, loc, memV);
    castB = genHostRegisterMemref(rewriter, loc, bufB);
    castBufC = genHostRegisterMemref(rewriter, loc, bufC);
  }

  Value rowA = genAllocCopy(rewriter, loc, memR, tokens);
  Value colA = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
  Value valA = genAllocCopy(rewriter, loc, memV, tokens);
  Value bufB = genTensorToMemref(rewriter, loc, b);
  Value matB = genAllocCopy(rewriter, loc, bufB, tokens);
  Value bufC = genTensorToMemref(rewriter, loc, c);
  if (gpuDataTransferStrategy == GPUDataTransferStrategy::kRegularDMA)
    bufB = genTensorToMemref(rewriter, loc, b);
  Value matB = isZeroCopy ? bufB : genAllocCopy(rewriter, loc, bufB, tokens);
  if (gpuDataTransferStrategy == GPUDataTransferStrategy::kRegularDMA)
    bufC = genTensorToMemref(rewriter, loc, c);
  Value matC = genAllocCopy(rewriter, loc, bufC, tokens);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();
@@ -649,11 +696,20 @@ static LogicalResult rewriteSpMM(PatternRewriter &rewriter,
    token = genDeallocMemRef(rewriter, loc, colA, token);
  token = genDeallocMemRef(rewriter, loc, valA, token);
  token = genDeallocMemRef(rewriter, loc, buffer, token);
  if (!isZeroCopy)
    token = genDeallocMemRef(rewriter, loc, matB, token);
  token = genCopyMemRef(rewriter, loc, bufC, matC, token);
  token = genDeallocMemRef(rewriter, loc, matC, token);
  tokens.push_back(token);
  genBlockingWait(rewriter, loc, tokens);
  if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) {
    genHostUnregisterMemref(rewriter, loc, castR);
    if (memC)
      genHostUnregisterMemref(rewriter, loc, castC);
    genHostUnregisterMemref(rewriter, loc, castV);
    genHostUnregisterMemref(rewriter, loc, castB);
    genHostUnregisterMemref(rewriter, loc, castC);
  }
  tokens.clear();

  // Done.
@@ -662,23 +718,41 @@ static LogicalResult rewriteSpMM(PatternRewriter &rewriter,
}

// Match and rewrite 2:4 SpMM kernels.
static LogicalResult rewrite2To4SpMM(PatternRewriter &rewriter,
                                     linalg::GenericOp op) {
static LogicalResult
rewrite2To4SpMM(PatternRewriter &rewriter, linalg::GenericOp op,
                GPUDataTransferStrategy gpuDataTransferStrategy) {
  Location loc = op.getLoc();
  Value A = op.getOperand(0);
  Value B = op.getOperand(1);
  Value C = op.getOperand(2); // we have C = AB
  SmallVector<Value> tokens;

  bool isZeroCopy =
      gpuDataTransferStrategy == GPUDataTransferStrategy::kZeroCopy;

  // All input should be dense tensors.
  if (!isDenseTensor(A) || !isDenseTensor(B) || !isDenseTensor(C))
    return failure();

  Value matA, matB;
  Value bufA = genTensorToMemref(rewriter, loc, A);
  Value matA = genAllocCopy(rewriter, loc, bufA, tokens);
  if (!isZeroCopy)
    matA = genAllocCopy(rewriter, loc, bufA, tokens);
  Value bufB = genTensorToMemref(rewriter, loc, B);
  Value matB = genAllocCopy(rewriter, loc, bufB, tokens);
  if (!isZeroCopy)
    matB = genAllocCopy(rewriter, loc, bufB, tokens);
  Value bufC = genTensorToMemref(rewriter, loc, C);
  Value castA, castB, castC;
  if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) {
    castA = genHostRegisterMemref(rewriter, loc, bufA);
    castB = genHostRegisterMemref(rewriter, loc, bufB);
    castC = genHostRegisterMemref(rewriter, loc, bufC);
  }

  if (isZeroCopy) {
    matA = bufA;
    matB = bufB;
  }
  Value matC = genAllocCopy(rewriter, loc, bufC, tokens);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();
@@ -754,26 +828,38 @@ static LogicalResult rewrite2To4SpMM(PatternRewriter &rewriter,
  token = genDeallocMemRef(rewriter, loc, buffer, token);
  token = genDeallocMemRef(rewriter, loc, buffer2, token);
  token = genDeallocMemRef(rewriter, loc, buffer3, token);

  if (!isZeroCopy)
    token = genDeallocMemRef(rewriter, loc, matA, token);
  if (!isZeroCopy)
    token = genDeallocMemRef(rewriter, loc, matB, token);
  token = genCopyMemRef(rewriter, loc, bufC, matC, token);
  token = genDeallocMemRef(rewriter, loc, matC, token);
  tokens.push_back(token);
  genBlockingWait(rewriter, loc, tokens);
  if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) {
    genHostUnregisterMemref(rewriter, loc, castA);
    genHostUnregisterMemref(rewriter, loc, castB);
    genHostUnregisterMemref(rewriter, loc, castC);
  }
  tokens.clear();
  rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, bufC);
  return success();
}

/// Match and rewrite SDDMM kernel.
static LogicalResult rewriteSDDMM(PatternRewriter &rewriter,
                                  linalg::GenericOp op, bool enableRT) {
static LogicalResult
rewriteSDDMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
             GPUDataTransferStrategy gpuDataTransferStrategy) {
  Location loc = op.getLoc();
  Value a = op.getOperand(0);
  Value b = op.getOperand(1);
  Value c = op.getOperand(2);
  SmallVector<Value> tokens;

  bool isZeroCopy =
      gpuDataTransferStrategy == GPUDataTransferStrategy::kZeroCopy;

  // Only admissible sparse matrix format and dense matrices, no COO.
  bool isCOO = false;
  SparseTensorType aTp = getSparseTensorType(a);
@@ -793,13 +879,31 @@ static LogicalResult rewriteSDDMM(PatternRewriter &rewriter,
  Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
  Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
  Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
  Value matA, matB;
  Value bufA = genTensorToMemref(rewriter, loc, a);
  Value matA = genAllocCopy(rewriter, loc, bufA, tokens);
  if (!isZeroCopy)
    matA = genAllocCopy(rewriter, loc, bufA, tokens);
  Value bufB = genTensorToMemref(rewriter, loc, b);
  Value matB = genAllocCopy(rewriter, loc, bufB, tokens);
  if (!isZeroCopy)
    matB = isZeroCopy ? bufB : genAllocCopy(rewriter, loc, bufB, tokens);
  Value memR = genFirstPosOrCrds(rewriter, loc, c, isCOO, enableRT);
  Value memC = genSecondCrds(rewriter, loc, c, isCOO, enableRT);
  Value memV = genToValues(rewriter, loc, c);

  Value castB, castA, castR, castC, castV;
  if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) {
    castB = genHostRegisterMemref(rewriter, loc, bufB);
    castA = genHostRegisterMemref(rewriter, loc, bufA);
    castR = genHostRegisterMemref(rewriter, loc, memR);
    if (memC)
      castC = genHostRegisterMemref(rewriter, loc, memC);
    castV = genHostRegisterMemref(rewriter, loc, memV);
  }

  if (isZeroCopy) {
    matA = bufA;
    matB = bufB;
  }
  Value rowC = genAllocCopy(rewriter, loc, memR, tokens);
  Value colC = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
  Value valC = genAllocCopy(rewriter, loc, memV, tokens);
@@ -850,8 +954,10 @@ static LogicalResult rewriteSDDMM(PatternRewriter &rewriter,
  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatC)
              .getAsyncToken();
  token = genDeallocMemRef(rewriter, loc, buffer, token);
  if (!isZeroCopy) {
    token = genDeallocMemRef(rewriter, loc, matA, token);
    token = genDeallocMemRef(rewriter, loc, matB, token);
  }
  token = genDeallocMemRef(rewriter, loc, rowC, token);
  if (colC)
    token = genDeallocMemRef(rewriter, loc, colC, token);
@@ -859,6 +965,14 @@ static LogicalResult rewriteSDDMM(PatternRewriter &rewriter,
  token = genDeallocMemRef(rewriter, loc, valC, token);
  tokens.push_back(token);
  genBlockingWait(rewriter, loc, tokens);
  if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) {
    genHostUnregisterMemref(rewriter, loc, castB);
    genHostUnregisterMemref(rewriter, loc, castA);
    genHostUnregisterMemref(rewriter, loc, castR);
    if (memC)
      genHostUnregisterMemref(rewriter, loc, castC);
    genHostUnregisterMemref(rewriter, loc, castV);
  }
  tokens.clear();

  // Done.
@@ -977,8 +1091,8 @@ private:
struct LinalgOpRewriter : public OpRewritePattern<linalg::GenericOp> {
  using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;

  LinalgOpRewriter(MLIRContext *context, bool rt)
      : OpRewritePattern(context), enableRT(rt) {}
  LinalgOpRewriter(MLIRContext *context, bool rt, GPUDataTransferStrategy t)
      : OpRewritePattern(context), enableRT(rt), gpuDataTransferStrategy(t) {}

  LogicalResult matchAndRewrite(linalg::GenericOp op,
                                PatternRewriter &rewriter) const override {
@@ -1004,7 +1118,7 @@ struct LinalgOpRewriter : public OpRewritePattern<linalg::GenericOp> {
        linalg::isReductionIterator(iteratorTypes[1]) &&
        // TODO: add transposed {i, j}
        maps == infer({{i, j}, {j}, {i}}) && matchSumOfMultOfArgs(op)) {
      return rewriteSpMV(rewriter, op, enableRT);
      return rewriteSpMV(rewriter, op, enableRT, gpuDataTransferStrategy);
    }

    // Recognize a SpMM kernel.
@@ -1016,9 +1130,9 @@ struct LinalgOpRewriter : public OpRewritePattern<linalg::GenericOp> {
        // TODO: maybe add transposed {i, j} in future
        maps == infer({{i, k}, {k, j}, {i, j}}) && matchSumOfMultOfArgs(op)) {
      if (op->getAttr("DENSE24"))
        return rewrite2To4SpMM(rewriter, op);
        return rewrite2To4SpMM(rewriter, op, gpuDataTransferStrategy);

      return rewriteSpMM(rewriter, op, enableRT);
      return rewriteSpMM(rewriter, op, enableRT, gpuDataTransferStrategy);
    }

    // Recognize a SDDMM kernel.
@@ -1030,7 +1144,7 @@ struct LinalgOpRewriter : public OpRewritePattern<linalg::GenericOp> {
        // TODO: maybe add transposed {i, j} in future
        maps == infer({{i, k}, {k, j}, {i, j}}) &&
        matchSumReductionOfMulUnary(op)) {
      return rewriteSDDMM(rewriter, op, enableRT);
      return rewriteSDDMM(rewriter, op, enableRT, gpuDataTransferStrategy);
    }

    return failure();
@@ -1038,6 +1152,7 @@ struct LinalgOpRewriter : public OpRewritePattern<linalg::GenericOp> {

private:
  bool enableRT;
  GPUDataTransferStrategy gpuDataTransferStrategy;
};

} // namespace
@@ -1057,7 +1172,9 @@ void mlir::populateSparseGPUCodegenPatterns(RewritePatternSet &patterns,
  patterns.add<ForallRewriter>(patterns.getContext(), numThreads);
}

void mlir::populateSparseGPULibgenPatterns(RewritePatternSet &patterns,
                                           bool enableRT) {
  patterns.add<LinalgOpRewriter>(patterns.getContext(), enableRT);
void mlir::populateSparseGPULibgenPatterns(
    RewritePatternSet &patterns, bool enableRT,
    GPUDataTransferStrategy gpuDataTransfer) {
  patterns.add<LinalgOpRewriter>(patterns.getContext(), enableRT,
                                 gpuDataTransfer);
}
+9 −3
Original line number Diff line number Diff line
@@ -65,6 +65,7 @@ struct SparsificationPass
  SparsificationPass(const SparsificationPass &pass) = default;
  SparsificationPass(const SparsificationOptions &options) {
    parallelization = options.parallelizationStrategy;
    gpuDataTransfer = options.gpuDataTransferStrategy;
    enableIndexReduction = options.enableIndexReduction;
    enableGPULibgen = options.enableGPULibgen;
    enableRuntimeLibrary = options.enableRuntimeLibrary;
@@ -73,12 +74,17 @@ struct SparsificationPass
  void runOnOperation() override {
    auto *ctx = &getContext();
    // Translate strategy flags to strategy options.
    SparsificationOptions options(parallelization, enableIndexReduction,
                                  enableGPULibgen, enableRuntimeLibrary);
    SparsificationOptions options(parallelization, gpuDataTransfer,
                                  enableIndexReduction, enableGPULibgen,
                                  enableRuntimeLibrary);
    // Apply GPU libgen (if requested), sparsification, and cleanup rewriting.
    RewritePatternSet patterns(ctx);
    if (enableGPULibgen) {
      populateSparseGPULibgenPatterns(patterns, enableRuntimeLibrary);
      // TODO: Zero copy is disabled due to correctness bugs. Tracker: #64316
      assert(gpuDataTransfer != GPUDataTransferStrategy::kZeroCopy &&
             "zero-copy transfer not supported with GPU libgen");
      populateSparseGPULibgenPatterns(patterns, enableRuntimeLibrary,
                                      gpuDataTransfer);
    }
    populateSparsificationPatterns(patterns, options);
    scf::ForOp::getCanonicalizationPatterns(patterns, ctx);
Loading