Commit d4db5289 authored by Peiming Liu

[mlir][sparse] extend unpack operation to support unpacking a batched COO type

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D149103
parent f9fbda71
+39 −5
@@ -124,9 +124,10 @@ def SparseTensor_PackOp : SparseTensor_Op<"pack", [Pure]>,
 }
 
 def SparseTensor_UnpackOp : SparseTensor_Op<"unpack">,
-    Arguments<(ins AnySparseTensor:$tensor)>,
-    Results<(outs 1DTensorOf<[AnyType]>:$values,
-                  2DTensorOf<[AnySignlessIntegerOrIndex]>:$coordinates,
+    Arguments<(ins AnySparseTensor:$tensor,
+                   OptionalAttr<IndexAttr>:$batched_lvls)>,
+    Results<(outs TensorOf<[AnyType]>:$values,
+                  TensorOf<[AnySignlessIntegerOrIndex]>:$coordinates,
                   AnySignlessIntegerOrIndex:$nse)> {
   let summary = "Returns the (values, coordinates) pair unpacked from the input tensor";

@@ -159,11 +160,44 @@ def SparseTensor_UnpackOp : SparseTensor_Op<"unpack">,
     // %coordinates = arith.constant dense<[[0,0], [1,2], [1,3]]> : tensor<3x2xindex>
     // %nse = 3
     ```
+
+    If `batched_lvls` is provided, the operation unpacks each batch of the tensors
+    separately. The returned `nse` is the maximum nse over all batches. For a batch
+    with a smaller nse, trailing zeros are appended to the result.
+    Example:
+
+    ```mlir
+    // input BCOO format |1.1, 2.2, 3.3, 0.0|
+    //      of 2x4 matrix |0.0, 1.2, 2.3, 0.0|
+    %values, %coordinates, %nse = sparse_tensor.unpack %st batched_lvls=1
+        : tensor<2x4xf64, #BCOO> to tensor<2x3xf64>, tensor<2x3x1xindex>, index
+    // %values      = arith.constant dense<[[ 1.1,   2.2,   3.3 ],
+    //                                      [ 1.2,   2.3,   0.0 ]]> : tensor<2x3xf64>
+    // %coordinates = arith.constant dense<[[ [0],   [1],   [2] ],
+    //                                      [ [1],   [2],   [0] ]]> : tensor<2x3x1xindex>
+    ```
   }];
+
+  let extraClassDeclaration = [{
+    /// Returns the number of leading levels that are batched.
+    unsigned getNumBatchedLvls();
+  }];
+
+  let builders = [
+    OpBuilder<(ins "Type":$values, "Type":$coordinates, "Type":$nse, "Value":$tensor),
+    [{
+      build($_builder, $_state, values, coordinates, nse, tensor, nullptr);
+    }]>,
+    OpBuilder<(ins "TypeRange":$resultTypes, "Value":$tensor),
+    [{
+      build($_builder, $_state, resultTypes, tensor, nullptr);
+    }]>
+  ];
+
 
   let assemblyFormat =
-    "$tensor attr-dict `:` type($tensor)"
-    "`to` type($values) `,` type($coordinates) `,` type($nse)";
+    "$tensor (`batched_lvls` `=` $batched_lvls^)? attr-dict `:`"
+    "type($tensor) `to` type($values) `,` type($coordinates) `,` type($nse)";
 
   let hasVerifier = 1;
 }
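The two convenience builders above simply forward a null `batched_lvls` attribute, so existing call sites keep working. A minimal sketch of creating a non-batched unpack through the first builder, assuming an `OpBuilder b`, a `Location loc`, and a sparse-tensor `Value st` already in scope (these names are illustrative, not from the patch):

```cpp
// Result types for a plain (non-batched) unpack of a 3x4 COO matrix with
// three stored entries. The convenience builder forwards a null
// batched_lvls attribute to the generated full builder.
Type valuesTp = RankedTensorType::get({3}, b.getF64Type());
Type coordsTp = RankedTensorType::get({3, 2}, b.getIndexType());
Type nseTp = b.getIndexType();
auto unpacked =
    b.create<sparse_tensor::UnpackOp>(loc, valuesTp, coordsTp, nseTp, st);
```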
+5 −1
@@ -719,7 +719,11 @@ LogicalResult UnpackOp::verify() {
   const auto coordinatesTp = getRankedTensorType(getCoordinates());
   const auto srcTp = getSparseTensorType(getTensor());
   return verifyPackUnPack(*this, false, srcTp, valuesTp, coordinatesTp,
-                          nullptr);
+                          getBatchedLvlsAttr());
 }
 
+unsigned UnpackOp::getNumBatchedLvls() {
+  return getBatchedLvls().has_value() ? getBatchedLvls()->getZExtValue() : 0;
+}
+
 LogicalResult ConvertOp::verify() {
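For readers tracing the verifier change: the attribute now reaches `verifyPackUnPack`, which must account for the leading batch dimensions. The shape contract implied by the documentation example can be stated as a hypothetical check (this is a sketch of the invariant, not the actual `verifyPackUnPack` body):

```cpp
// Sketch of the batched invariant only. With batched_lvls = n, the leading
// n dimensions of the values and coordinates results are batch dimensions
// and must agree: values is <B0 x ... x B(n-1) x nse x T>, coordinates is
// <B0 x ... x B(n-1) x nse x (lvlRank - n) x iTp>.
static bool batchDimsAgree(ShapedType valuesTp, ShapedType coordsTp,
                           unsigned n) {
  for (unsigned i = 0; i < n; ++i)
    if (valuesTp.getDimSize(i) != coordsTp.getDimSize(i))
      return false; // batch dimensions must match across the two results
  return true;
}
```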
+6 −3
@@ -153,9 +153,12 @@ struct UnpackOpInterface
     : public BufferizableOpInterface::ExternalModel<UnpackOpInterface,
                                                     sparse_tensor::UnpackOp> {
   bool bufferizesToAllocation(Operation *op, OpResult opResult) const {
-    // Similar to InsertOp, reallocation is not considered to allocate a new
-    // piece of memory.
-    return false;
+    // We allocate and return unpacked memory if this is a batched unpack.
+    // When the number of batched levels is zero, we reuse the
+    // coordinates/values memrefs (and reallocate if the requested output size
+    // is larger than the actual size). Similar to InsertOp, reallocation is
+    // not considered to allocate a new piece of memory.
+    return llvm::cast<UnpackOp>(op).getNumBatchedLvls() != 0;
   }
 
   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
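The return value above is the behavioral pivot of the bufferization change. A hedged summary of the two paths as the comment describes them (the variable name is illustrative):

```cpp
// Assuming `unpackOp` is a sparse_tensor::UnpackOp under bufferization.
bool allocates = unpackOp.getNumBatchedLvls() != 0;
// allocates == true  (batched): fresh output buffers are allocated and each
//   batch is written out, padded with trailing zeros up to the maximal nse,
//   as described in the op documentation.
// allocates == false (non-batched): the op forwards the underlying
//   coordinates/values memrefs, at most growing them with a realloc, which
//   the interface does not count as a new allocation.
```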
+12 −0
@@ -213,6 +213,18 @@ Value sparse_tensor::genCast(OpBuilder &builder, Location loc, Value value,
   return mlir::convertScalarToDtype(builder, loc, value, dstTp, isUnsignedCast);
 }
 
+Value sparse_tensor::genIndexLoad(OpBuilder &builder, Location loc, Value mem,
+                                  Value s) {
+  Value load = builder.create<memref::LoadOp>(loc, mem, s);
+  if (!load.getType().isa<IndexType>()) {
+    if (load.getType().getIntOrFloatBitWidth() < 64)
+      load = builder.create<arith::ExtUIOp>(loc, builder.getI64Type(), load);
+    load =
+        builder.create<arith::IndexCastOp>(loc, builder.getIndexType(), load);
+  }
+  return load;
+}
+
 mlir::TypedAttr mlir::sparse_tensor::getOneAttr(Builder &builder, Type tp) {
   if (tp.isa<FloatType>())
     return builder.getFloatAttr(tp, 1.0);
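A usage sketch for the new helper, assuming an `OpBuilder b`, a `Location loc`, a `memref<?xi32>`-typed value `posMem`, and an `index`-typed subscript `i` (all illustrative names):

```cpp
// Loads a position/coordinate and normalizes it to the index type.
Value p = sparse_tensor::genIndexLoad(b, loc, posMem, i);
// For an i32 element type this emits, in order:
//   %0 = memref.load %posMem[%i] : memref<?xi32>
//   %1 = arith.extui %0 : i32 to i64        // zero extend narrow data first
//   %2 = arith.index_cast %1 : i64 to index
// An index-typed load is returned as-is; an i64 load only gets the
// arith.index_cast.
```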
+5 −0
@@ -75,6 +75,11 @@ StringRef primaryTypeFunctionSuffix(Type elemTp);
 /// Add type casting between arith and index types when needed.
 Value genCast(OpBuilder &builder, Location loc, Value value, Type dstTy);
 
+/// Generates a pointer/index load from the sparse storage scheme. Narrower
+/// data types need to be zero extended before casting the value into the
+/// index type used for looping and indexing.
+Value genIndexLoad(OpBuilder &builder, Location loc, Value mem, Value s);
+
 /// Generates a 1-valued attribute of the given type.  This supports
 /// all the same types as `getZeroAttr`; however, unlike `getZeroAttr`,
 /// for unsupported types we raise `llvm_unreachable` rather than