Unverified Commit 2ae37be4 authored by bjacob's avatar bjacob Committed by GitHub
Browse files

Allow empty dimension arrays in `linalg::inferContractionDims` (#69496)

This function was returning failure when any of the intersection sets
was empty, but this is actually legitimate in "matrix times vector"
cases, where some of the operands have lower dimensionality, implying
unit-dimension semantics for the "missing" dimensions.

Example:

```mlir
func.func @transpose_extend_batch_matmul(
    %vec: tensor<32x128xi16>,
    %mat: tensor<11008x32x128xi4>) -> tensor<11008x32xi32> {
  %c0_i32 = arith.constant 0 : i32
  %cst_0 = arith.constant 0.000000e+00 : f32
  %0 = tensor.empty() : tensor<11008x32xi32>
  %1 = linalg.fill ins(%c0_i32 : i32) outs(%0 : tensor<11008x32xi32>) -> tensor<11008x32xi32>
  %2 = tensor.empty() : tensor<11008xf32>
  %3 = linalg.fill ins(%cst_0 : f32) outs(%2 : tensor<11008xf32>) -> tensor<11008xf32>
  %batch_matmul_result = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, 
                                                          affine_map<(d0, d1, d2) -> (d0, d1, d2)>, 
                                                          affine_map<(d0, d1, d2) -> (d0, d1)>], 
                                         iterator_types = ["parallel", "parallel", "reduction"]} 
                                         ins(%vec, %mat : tensor<32x128xi16>, tensor<11008x32x128xi4>) 
                                         outs(%1 : tensor<11008x32xi32>) {
  ^bb0(%in: i16, %in_3: i4, %out: i32):
      %19 = arith.extsi %in : i16 to i32
      %20 = arith.extui %in_3 : i4 to i32
      %21 = arith.muli %19, %20 : i32
      %22 = arith.addi %21, %out : i32
      linalg.yield %22 : i32
  } -> tensor<11008x32xi32>
  return %batch_matmul_result : tensor<11008x32xi32>
}
```

Here, we were returning failure because `ac` is empty. With this PR, we
return this useful information:

```
batch: [ 1 ]
m: [ ]
n: [ 0 ]
k: [ 2 ]
```
parent fdfe0b09
Loading
Loading
Loading
Loading
+0 −3
Original line number Diff line number Diff line
@@ -227,9 +227,6 @@ mlir::linalg::inferContractionDims(LinalgOp linalgOp) {
      linalgOp, linalgOp.getDpsInputOperand(1), red);
  llvm::set_intersect(ra, rb);

  if (ac.empty() || bc.empty() || ra.empty())
    return failure();

  // Return each set in sorted order.
  ContractionDimensions dimensions{
      SmallVector<unsigned, 2>(batches.begin(), batches.end()),
+13 −0
Original line number Diff line number Diff line
@@ -910,6 +910,19 @@ module attributes { transform.target_tag = "start_here" } {
    return %result : tensor<10x15xf64>
  }

  func.func @vecmat_simple(%lhs: tensor<20xf32>, %rhs: tensor<20x15xf32>) -> tensor<15xf64> {
    %cst = arith.constant 0.0 : f64
    %empty = tensor.empty() : tensor<15xf64>
    %fill = linalg.fill ins(%cst : f64) outs(%empty : tensor<15xf64>) -> tensor<15xf64>
    // expected-remark @below {{contraction}}
    // expected-remark @below {{batch dims}}
    // expected-remark @below {{m dims}}
    // expected-remark @below {{n dims 0}}
    // expected-remark @below {{k dims 1}}
    %result = linalg.vecmat ins(%lhs, %rhs: tensor<20xf32>, tensor<20x15xf32>) outs(%fill: tensor<15xf64>) -> tensor<15xf64>
    return %result : tensor<15xf64>
  }

  func.func @double_batch(%lhs: tensor<40x10x50x20xf32>, %rhs: tensor<40x20x50x15xf32>) -> tensor<40x10x50x15xf32> {
    %cst = arith.constant 0.0 : f32
    %empty = tensor.empty() : tensor<40x10x50x15xf32>